From 20a36bb2658aa7676c66af478cf8cacbe449f955 Mon Sep 17 00:00:00 2001
From: Iria Ayan-Miguez <iria.ayan-miguez_2@ecmwf.int>
Date: Wed, 27 Mar 2024 11:31:13 +0000
Subject: [PATCH] Add a 2D MPL_ALLREDUCE interface

---
 src/fiat/mpl/internal/mpi4to8_m.F90         |  67 +++++-
 src/fiat/mpl/internal/mpl_allreduce_mod.F90 | 241 +++++++++++++++++++-
 2 files changed, 306 insertions(+), 2 deletions(-)

diff --git a/src/fiat/mpl/internal/mpi4to8_m.F90 b/src/fiat/mpl/internal/mpi4to8_m.F90
index 1c2e4d1c..f841373e 100644
--- a/src/fiat/mpl/internal/mpi4to8_m.F90
+++ b/src/fiat/mpl/internal/mpi4to8_m.F90
@@ -26,7 +26,8 @@ MODULE MPI4TO8_M
 
 INTERFACE MPI_ALLREDUCE8
   MODULE PROCEDURE MPI_ALLREDUCE8_R4, MPI_ALLREDUCE8_R8, &
-                   MPI_ALLREDUCE8_I4
+                   MPI_ALLREDUCE8_I4, MPI_ALLREDUCE8_R4_2D, &
+                   MPI_ALLREDUCE8_R8_2D
 END INTERFACE MPI_ALLREDUCE8
 
 INTERFACE MPI_ALLTOALLV8
@@ -179,6 +180,43 @@ SUBROUTINE MPI_ALLREDUCE8_R4(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
 
 END SUBROUTINE MPI_ALLREDUCE8_R4
 
+SUBROUTINE MPI_ALLREDUCE8_R4_2D(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
+                             COMM, IERROR)
+  ! Rank-2 REAL(JPRM) variant of MPI_ALLREDUCE8: forwards to MPI_ALLREDUCE
+  ! with INTEGER(KIND=8) arguments and REAL(KIND=8) copies of the buffers.
+  !   SENDDATA : local contribution to the reduction
+  !   RECVDATA : receives the reduced values, converted back to JPRM
+  !   COUNT, DATATYPE, OP, COMM : widened to KIND=8 and passed through
+  !   IERROR   : error code returned by MPI_ALLREDUCE, narrowed to JPIM
+  REAL(KIND=JPRM), INTENT(IN) :: SENDDATA(:,:)
+  INTEGER(KIND=JPIM), INTENT(IN) :: COUNT, DATATYPE, OP, COMM
+  REAL(KIND=JPRM), INTENT(OUT) :: RECVDATA(:,:)
+  INTEGER(KIND=JPIM), INTENT(OUT) :: IERROR
+
+  REAL(KIND=8), DIMENSION(:,:), ALLOCATABLE :: SENDDATA8, RECVDATA8
+  INTEGER(KIND=8) :: COUNT8, DATATYPE8, OP8, COMM8, IERROR8
+
+  ! The work arrays are rank 2, so both extents must be given explicitly.
+  ALLOCATE(SENDDATA8(SIZE(SENDDATA,1),SIZE(SENDDATA,2)))
+  ALLOCATE(RECVDATA8(SIZE(RECVDATA,1),SIZE(RECVDATA,2)))
+
+  SENDDATA8 = SENDDATA
+  COUNT8 = COUNT
+  DATATYPE8 = DATATYPE
+  OP8 = OP
+  COMM8 = COMM
+
+  CALL MPI_ALLREDUCE(SENDDATA8, RECVDATA8, COUNT8, DATATYPE8, OP8, COMM8, IERROR8)
+
+  ! Convert the reduced values and the status back to the caller's kinds.
+  RECVDATA = RECVDATA8
+  IERROR = IERROR8
+  DEALLOCATE(SENDDATA8)
+  DEALLOCATE(RECVDATA8)
+
+END SUBROUTINE MPI_ALLREDUCE8_R4_2D
+
+
 ! ---------------------------------------------------------
 SUBROUTINE MPI_ALLREDUCE8_R8(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
                              COMM, IERROR)
@@ -206,6 +244,33 @@ SUBROUTINE MPI_ALLREDUCE8_R8(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
 
 END SUBROUTINE MPI_ALLREDUCE8_R8
 
+SUBROUTINE MPI_ALLREDUCE8_R8_2D(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
+                             COMM, IERROR)
+  ! Rank-2 REAL(JPRD) variant of MPI_ALLREDUCE8: the buffers are handed to
+  ! MPI_ALLREDUCE unchanged; only the integer arguments are widened to
+  ! INTEGER(KIND=8) and the returned error code is narrowed back to JPIM.
+  !   SENDDATA : local contribution to the reduction
+  !   RECVDATA : receives the reduced values
+  !   COUNT, DATATYPE, OP, COMM : passed through after widening
+  REAL(KIND=JPRD), INTENT(IN) :: SENDDATA(:,:)
+  INTEGER(KIND=JPIM), INTENT(IN) :: COUNT, DATATYPE, OP, COMM
+  REAL(KIND=JPRD), INTENT(OUT) :: RECVDATA(:,:)
+  INTEGER(KIND=JPIM), INTENT(OUT) :: IERROR
+
+  INTEGER(KIND=8) :: COUNT8, DATATYPE8, OP8, COMM8, IERROR8
+
+  COUNT8 = COUNT
+  DATATYPE8 = DATATYPE
+  OP8 = OP
+  COMM8 = COMM
+
+  CALL MPI_ALLREDUCE(SENDDATA, RECVDATA, COUNT8, DATATYPE8, OP8, COMM8, IERROR8)
+
+  IERROR = IERROR8
+
+END SUBROUTINE MPI_ALLREDUCE8_R8_2D
+
+
 ! ---------------------------------------------------------
 SUBROUTINE MPI_ALLREDUCE8_I4(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
                               COMM, IERROR)
diff --git a/src/fiat/mpl/internal/mpl_allreduce_mod.F90 b/src/fiat/mpl/internal/mpl_allreduce_mod.F90
index 1e412ea5..adf83640 100644
--- a/src/fiat/mpl/internal/mpl_allreduce_mod.F90
+++ b/src/fiat/mpl/internal/mpl_allreduce_mod.F90
@@ -85,7 +85,8 @@ MODULE MPL_ALLREDUCE_MOD
 MODULE PROCEDURE MPL_ALLREDUCE_REAL8, MPL_ALLREDUCE_REAL4, MPL_ALLREDUCE_INT, &
   MPL_ALLREDUCE_INT8, &
   MPL_ALLREDUCE_REAL8_SCALAR, MPL_ALLREDUCE_REAL4_SCALAR, &
-  MPL_ALLREDUCE_INT_SCALAR, MPL_ALLREDUCE_INT8_SCALAR
+  MPL_ALLREDUCE_INT_SCALAR, MPL_ALLREDUCE_INT8_SCALAR, &
+  MPL_ALLREDUCE_REAL4_2D, MPL_ALLREDUCE_REAL8_2D
 END INTERFACE
 
 PUBLIC MPL_ALLREDUCE
@@ -459,6 +460,135 @@ SUBROUTINE MPL_ALLREDUCE_REAL8(PSENDBUF,CDOPER,LDREPROD, &
 
 END SUBROUTINE MPL_ALLREDUCE_REAL8
 
+SUBROUTINE MPL_ALLREDUCE_REAL8_2D(PSENDBUF,CDOPER,LDREPROD, &
+                            & KCOMM,KERROR,CDSTRING)
+! In-place all-reduce of a rank-2 REAL(JPRD) array; CDOPER(1:3) selects MAX, MIN or SUM.
+! The communicator is KCOMM when present, otherwise MPL_COMM_OML(ITID), ITID = OML_MY_THREAD().
+#ifdef USE_8_BYTE_WORDS
+  USE MPI4TO8, ONLY : &
+    MPI_ALLREDUCE => MPI_ALLREDUCE8
+#endif
+! PSENDBUF is overwritten with the reduced values only when MPL_NUMPROC > 1.
+REAL(KIND=JPRD),INTENT(INOUT)        :: PSENDBUF(:,:)
+CHARACTER(LEN=*),INTENT(IN)    :: CDOPER
+LOGICAL,INTENT(IN),OPTIONAL :: LDREPROD
+INTEGER(KIND=JPIM),INTENT(IN),OPTIONAL :: KCOMM
+INTEGER(KIND=JPIM),INTENT(OUT),OPTIONAL :: KERROR
+CHARACTER(LEN=*),INTENT(IN),OPTIONAL :: CDSTRING
+REAL(KIND=JPRD)            :: ZRECVBUF(SIZE(PSENDBUF,1),SIZE(PSENDBUF,2)) ! valid for zero extents too
+! KERROR, when present, receives IERROR; otherwise a non-zero IERROR goes to MPL_MESSAGE.
+LOGICAL :: LLREPRODSUM
+INTEGER(KIND=JPIM) :: ISENDCOUNT,ICOMM,IERROR,IOPER
+! LLREPRODSUM is recorded but not acted on: the reproducible-summation branch
+! further down is commented out, so SUM always goes through MPI_ALLREDUCE.
+INTEGER(KIND=JPIM) :: ITID
+IERROR = 0
+ITID = OML_MY_THREAD()
+LLREPRODSUM = .FALSE.
+
+IF(MPL_NUMPROC < 1) CALL MPL_MESSAGE( &
+  & CDMESSAGE='MPL_ALLREDUCE: MPL NOT INITIALISED ',LDABORT=LLABORT)
+
+IF(CDOPER(1:3) == 'MAX' .OR. CDOPER(1:3) == 'max' ) THEN
+  IOPER = MPI_MAX
+ELSEIF(CDOPER(1:3) == 'MIN' .OR. CDOPER(1:3) == 'min' ) THEN
+  IOPER = MPI_MIN
+ELSEIF(CDOPER(1:3) == 'SUM' .OR. CDOPER(1:3) == 'sum' ) THEN
+  IOPER = MPI_SUM
+  IF (PRESENT(LDREPROD)) THEN
+    LLREPRODSUM = LDREPROD
+  ELSE
+    CALL MPL_MESSAGE(IERROR,&
+     & 'MPL_ALLREDUCE: SUMMATION OPERATOR NOT REPRODUCIBLE IN REAL MODE',&
+     & CDSTRING,LDABORT=LLABORT)
+  ENDIF
+ELSE
+  CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE: ERROR UNKNOWN OPERATOR',&
+   & CDSTRING,LDABORT=LLABORT)
+ENDIF
+
+IF(PRESENT(KCOMM)) THEN
+  ICOMM=KCOMM
+ELSE
+  ICOMM=MPL_COMM_OML(ITID)
+ENDIF
+
+ISENDCOUNT = SIZE(PSENDBUF)
+!#ifndef NAG
+!IF (ISENDCOUNT > 0) THEN
+!  IF( (LOC(PSENDBUF(UBOUND(PSENDBUF,1)))-LOC(PSENDBUF(LBOUND(PSENDBUF,1)))) /= 8_JPIB*(ISENDCOUNT - 1) ) THEN
+!    CALL MPL_MESSAGE(CDMESSAGE='MPL_ALLREDUCE: BUFFER NOT CONTIGUOUS ',LDABORT=LLABORT)
+!  ENDIF
+!ENDIF
+!#endif
+! NOTE: the disabled code below subscripts the buffers as rank-1 arrays and uses locals not declared here.
+!IF (LLREPRODSUM) THEN
+!-- Near reproducible summation (independent of number of threads)
+
+!  IP2=0
+!  DO
+!    IP2=IP2+1
+!    IF(2**IP2 >= MPL_NUMPROC) EXIT
+!  ENDDO
+
+!  IMSENT=0
+!  DO JSTAGE=IP2,1,-1
+!    WRITE(0,*) 'STAGE ',JSTAGE
+!    ITAG  = 2001+JSTAGE
+!    II    = 2**JSTAGE
+!    IHALF = II/2
+!    ISEND = MPL_RANK - IHALF
+!    IF(ISEND > 0 .AND. MPL_RANK <= II) THEN
+!      IMSENT=IMSENT+1
+!      CALL MPL_SEND(PSENDBUF,KDEST=ISEND,KCOMM=ICOMM,KTAG=ITAG,KERROR=IERROR,&
+!       &KMP_TYPE=JP_NON_BLOCKING_STANDARD,KREQUEST=ISREQ(IMSENT),CDSTRING='MPLS_SEND')
+!      write(0,*) 'I SEND TO ',MPL_RANK,ISEND
+!    ENDIF
+!    IRECV=MPL_RANK + IHALF
+!    IF(IRECV <=MPL_NUMPROC .AND. MPL_RANK <= IHALF) THEN
+!      CALL MPL_RECV(ZRECVBUF,KSOURCE=IRECV,KCOMM=ICOMM,KTAG=ITAG,&
+!       &KERROR=IERROR,KOUNT=ICOUNT)
+!      write(0,*) 'I RECV FROM ',MPL_RANK,IRECV
+!      PSENDBUF(:) = PSENDBUF(:) + ZRECVBUF(:)
+!    ENDIF
+!  ENDDO
+!  IF(IMSENT > 0) THEN
+!    CALL MPL_WAIT(KREQUEST=ISREQ(1:IMSENT),CDSTRING='MPLS_SEND')
+!  ENDIF
+!  IF (MPL_RANK == 1) THEN
+!    ZRECVBUF(:) = PSENDBUF(:)
+!  ENDIF
+!  write(0,*) 'enter broadcast '
+!  CALL MPL_BROADCAST(ZRECVBUF,KTAG=ITAG,KCOMM=ICOMM,KROOT=1,KERROR=IERROR)
+!  write(0,*) 'exit broadcast '
+
+!ELSE
+  IF ( MPL_NUMPROC > 1 ) &
+  CALL MPI_ALLREDUCE(PSENDBUF,ZRECVBUF,ISENDCOUNT,INT(MPI_REAL8), &
+                  &  IOPER,ICOMM,IERROR)
+
+  IF(LMPLSTATS) THEN
+    CALL MPL_SENDSTATS(ISENDCOUNT,INT(MPI_REAL8))
+    CALL MPL_RECVSTATS(ISENDCOUNT,INT(MPI_REAL8))
+  ENDIF
+
+!ENDIF
+
+IF(MPL_OUTPUT > 1 )THEN
+  WRITE(MPL_UNIT,'(A,5I8)') ' MPL_ALLREDUCE ',ISENDCOUNT,ICOMM,IOPER
+ENDIF
+IF(PRESENT(KERROR)) THEN
+  KERROR=IERROR
+ELSE
+  IF(IERROR /= 0 ) CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT)
+ENDIF
+
+IF ( MPL_NUMPROC > 1 ) &
+PSENDBUF(:,:) = ZRECVBUF(:,:)
+
+END SUBROUTINE MPL_ALLREDUCE_REAL8_2D
+
+
 
 SUBROUTINE MPL_ALLREDUCE_REAL4(PSENDBUF,CDOPER,LDREPROD, &
                             & KCOMM,KERROR,CDSTRING)
@@ -568,6 +698,115 @@ SUBROUTINE MPL_ALLREDUCE_REAL4(PSENDBUF,CDOPER,LDREPROD, &
 
 END SUBROUTINE MPL_ALLREDUCE_REAL4
 
+SUBROUTINE MPL_ALLREDUCE_REAL4_2D(PSENDBUF,CDOPER,LDREPROD, &
+                            & KCOMM,KERROR,CDSTRING)
+! In-place all-reduce of a rank-2 REAL(JPRM) array; CDOPER(1:3) selects MAX, MIN or SUM.
+! The communicator is KCOMM when present, otherwise MPL_COMM_OML(ITID), ITID = OML_MY_THREAD().
+#ifdef USE_8_BYTE_WORDS
+  USE MPI4TO8, ONLY : &
+    MPI_ALLREDUCE => MPI_ALLREDUCE8
+#endif
+! PSENDBUF is overwritten with the reduced values only when MPL_NUMPROC > 1.
+REAL(KIND=JPRM),INTENT(INOUT)        :: PSENDBUF(:,:)
+CHARACTER(LEN=*),INTENT(IN)    :: CDOPER
+LOGICAL,INTENT(IN),OPTIONAL :: LDREPROD
+INTEGER(KIND=JPIM),INTENT(IN),OPTIONAL :: KCOMM
+INTEGER(KIND=JPIM),INTENT(OUT),OPTIONAL :: KERROR
+CHARACTER(LEN=*),INTENT(IN),OPTIONAL :: CDSTRING
+REAL(KIND=JPRM)            :: ZRECVBUF(SIZE(PSENDBUF,1),SIZE(PSENDBUF,2)) ! valid for zero extents too
+! KERROR, when present, receives IERROR; otherwise a non-zero IERROR goes to MPL_MESSAGE.
+LOGICAL :: LLREPRODSUM
+INTEGER(KIND=JPIM) :: ISENDCOUNT,ICOMM,IERROR,IOPER
+INTEGER(KIND=JPIM) :: ITID
+IERROR = 0
+ITID = OML_MY_THREAD()
+LLREPRODSUM = .FALSE.
+
+IF(MPL_NUMPROC < 1) CALL MPL_MESSAGE( &
+  & CDMESSAGE='MPL_ALLREDUCE: MPL NOT INITIALISED ',LDABORT=LLABORT)
+
+IF(CDOPER(1:3) == 'MAX' .OR. CDOPER(1:3) == 'max' ) THEN
+  IOPER = MPI_MAX
+ELSEIF(CDOPER(1:3) == 'MIN' .OR. CDOPER(1:3) == 'min' ) THEN
+  IOPER = MPI_MIN
+ELSEIF(CDOPER(1:3) == 'SUM' .OR. CDOPER(1:3) == 'sum' ) THEN
+  IOPER = MPI_SUM
+  IF (PRESENT(LDREPROD)) THEN
+    LLREPRODSUM = LDREPROD
+  ELSE
+    CALL MPL_MESSAGE(IERROR,&
+     & 'MPL_ALLREDUCE: SUMMATION OPERATOR NOT REPRODUCIBLE IN REAL MODE',&
+     & CDSTRING,LDABORT=LLABORT)
+  ENDIF
+ELSE
+  CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE: ERROR UNKNOWN OPERATOR',&
+   & CDSTRING,LDABORT=LLABORT)
+ENDIF
+
+IF(PRESENT(KCOMM)) THEN
+  ICOMM=KCOMM
+ELSE
+  ICOMM=MPL_COMM_OML(ITID)
+ENDIF
+
+ISENDCOUNT = SIZE(PSENDBUF)
+!#ifndef NAG
+!IF (ISENDCOUNT > 0) THEN
+!  IF( (LOC(PSENDBUF(UBOUND(PSENDBUF,1)))-LOC(PSENDBUF(LBOUND(PSENDBUF,1)))) /= 4_JPIB*(ISENDCOUNT - 1) ) THEN
+!    CALL MPL_MESSAGE(CDMESSAGE='MPL_ALLREDUCE: BUFFER NOT CONTIGUOUS ',LDABORT=LLABORT)
+!  ENDIF
+!ENDIF
+!#endif
+! NOTE: LLREPRODSUM is not acted on; the disabled branch below uses rank-1 subscripts and undeclared locals.
+!IF (LLREPRODSUM) THEN
+!-- Near reproducible summation
+!  ITAG = 2001
+!  IF (MPL_RANK == 1) THEN
+!    DO IPROC=2,MPL_NUMPROC
+!      CALL MPL_RECV(ZRECVBUF,KSOURCE=IPROC,KCOMM=ICOMM,KTAG=ITAG,&
+!        &KERROR=IERROR,KOUNT=ICOUNT)
+!      IF (ICOUNT /= ISENDCOUNT) THEN
+!        WRITE(MPL_ERRUNIT,'(A,I10,A,I6,A,I10)')&
+!        & 'MPL_ALLREDUCE: RECEIVED UNEXPECTED NUMBER OF ELEMENTS ', &
+!        & ICOUNT,' FROM PROC ',IPROC,'. EXPECTED=',ISENDCOUNT
+!        CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT)
+!      ENDIF
+!      PSENDBUF(:) = PSENDBUF(:) + ZRECVBUF(:)
+!    ENDDO
+!    ZRECVBUF(:) = PSENDBUF(:)
+!  ELSE
+!    CALL MPL_SEND(PSENDBUF,KDEST=1,KCOMM=ICOMM,KTAG=ITAG,KERROR=IERROR,&
+!      &KMP_TYPE=JP_BLOCKING_STANDARD,CDSTRING='MPLS_SEND')
+!  ENDIF
+!  ITAG = ITAG + 1
+!  CALL MPL_BROADCAST(ZRECVBUF,KTAG=ITAG,KCOMM=ICOMM,KROOT=1,KERROR=IERROR)
+!ELSE
+IF ( MPL_NUMPROC > 1 ) &
+CALL MPI_ALLREDUCE(PSENDBUF,ZRECVBUF,ISENDCOUNT,INT(MPI_REAL4), &
+                  &  IOPER,ICOMM,IERROR)
+
+IF(LMPLSTATS) THEN
+  CALL MPL_SENDSTATS(ISENDCOUNT,INT(MPI_REAL4))
+  CALL MPL_RECVSTATS(ISENDCOUNT,INT(MPI_REAL4))
+ENDIF
+
+!ENDIF
+
+IF(MPL_OUTPUT > 1 )THEN
+  WRITE(MPL_UNIT,'(A,5I8)') ' MPL_ALLREDUCE ',ISENDCOUNT,ICOMM,IOPER
+ENDIF
+IF(PRESENT(KERROR)) THEN
+  KERROR=IERROR
+ELSE
+  IF(IERROR /= 0 ) CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT)
+ENDIF
+
+IF ( MPL_NUMPROC > 1 ) &
+PSENDBUF(:,:) = ZRECVBUF(:,:)
+
+END SUBROUTINE MPL_ALLREDUCE_REAL4_2D
+
+
 END MODULE MPL_ALLREDUCE_MOD