From 20a36bb2658aa7676c66af478cf8cacbe449f955 Mon Sep 17 00:00:00 2001 From: Iria Ayan-Miguez <iria.ayan-miguez_2@ecmwf.int> Date: Wed, 27 Mar 2024 11:31:13 +0000 Subject: [PATCH] Add a 2D MPL_ALLREDUCE interface --- src/fiat/mpl/internal/mpi4to8_m.F90 | 67 +++++- src/fiat/mpl/internal/mpl_allreduce_mod.F90 | 241 +++++++++++++++++++- 2 files changed, 306 insertions(+), 2 deletions(-) diff --git a/src/fiat/mpl/internal/mpi4to8_m.F90 b/src/fiat/mpl/internal/mpi4to8_m.F90 index 1c2e4d1c..f841373e 100644 --- a/src/fiat/mpl/internal/mpi4to8_m.F90 +++ b/src/fiat/mpl/internal/mpi4to8_m.F90 @@ -26,7 +26,8 @@ MODULE MPI4TO8_M INTERFACE MPI_ALLREDUCE8 MODULE PROCEDURE MPI_ALLREDUCE8_R4, MPI_ALLREDUCE8_R8, & - MPI_ALLREDUCE8_I4 + MPI_ALLREDUCE8_I4, MPI_ALLREDUCE8_R4_2D, & + MPI_ALLREDUCE8_R8_2D END INTERFACE MPI_ALLREDUCE8 INTERFACE MPI_ALLTOALLV8 @@ -179,6 +180,43 @@ SUBROUTINE MPI_ALLREDUCE8_R4(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, & END SUBROUTINE MPI_ALLREDUCE8_R4 +SUBROUTINE MPI_ALLREDUCE8_R4_2D(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, & + COMM, IERROR) + + REAL(KIND=JPRM), DIMENSION(:,:), INTENT(IN) :: & + SENDDATA(:,:) + INTEGER(KIND=JPIM), INTENT(IN) :: & + COUNT, DATATYPE, OP, COMM + REAL(KIND=JPRM), DIMENSION(:,:), INTENT(OUT) :: & + RECVDATA(:.:) + INTEGER(KIND=JPIM), INTENT(OUT) :: & + IERROR + + REAL(KIND=8), DIMENSION(:,:), ALLOCATABLE :: & + SENDDATA8, RECVDATA8 + INTEGER(KIND=8) :: & + COUNT8, DATATYPE8, OP8, COMM8, IERROR8 + + ALLOCATE(SENDDATA8(SIZE(SENDDATA))) + ALLOCATE(RECVDATA8(SIZE(RECVDATA))) + + SENDDATA8 = SENDDATA + COUNT8 = COUNT + DATATYPE8 = DATATYPE + OP8 = OP + COMM8 = COMM + + CALL MPI_ALLREDUCE(SENDDATA8, RECVDATA8, COUNT8, DATATYPE8, OP8, COMM8, IERROR8) + + RECVDATA = RECVDATA8 + IERROR = IERROR8 + + DEALLOCATE(SENDDATA8) + DEALLOCATE(RECVDATA8) + +END SUBROUTINE MPI_ALLREDUCE8_R4_2D + + ! --------------------------------------------------------- SUBROUTINE MPI_ALLREDUCE8_R8(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, & COMM, IERROR) @@ -206,6 +244,33 @@ SUBROUTINE MPI_ALLREDUCE8_R8(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, & END SUBROUTINE MPI_ALLREDUCE8_R8 +SUBROUTINE MPI_ALLREDUCE8_R8_2D(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, & + COMM, IERROR) + + REAL(KIND=JPRD), DIMENSION(:,:), INTENT(IN) :: & + SENDDATA(:,:) + INTEGER(KIND=JPIM), INTENT(IN) :: & + COUNT, DATATYPE, OP, COMM + REAL(KIND=JPRD), DIMENSION(:,:), INTENT(OUT) :: & + RECVDATA(:,:) + INTEGER(KIND=JPIM), INTENT(OUT) :: & + IERROR + + INTEGER(KIND=8) :: & + COUNT8, DATATYPE8, OP8, COMM8, IERROR8 + + COUNT8 = COUNT + DATATYPE8 = DATATYPE + OP8 = OP + COMM8 = COMM + + CALL MPI_ALLREDUCE(SENDDATA, RECVDATA, COUNT8, DATATYPE8, OP8, COMM8, IERROR8) + + IERROR = IERROR8 + +END SUBROUTINE MPI_ALLREDUCE8_R8 + + ! --------------------------------------------------------- SUBROUTINE MPI_ALLREDUCE8_I4(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, & COMM, IERROR) diff --git a/src/fiat/mpl/internal/mpl_allreduce_mod.F90 b/src/fiat/mpl/internal/mpl_allreduce_mod.F90 index 1e412ea5..adf83640 100644 --- a/src/fiat/mpl/internal/mpl_allreduce_mod.F90 +++ b/src/fiat/mpl/internal/mpl_allreduce_mod.F90 @@ -85,7 +85,8 @@ MODULE MPL_ALLREDUCE_MOD MODULE PROCEDURE MPL_ALLREDUCE_REAL8, MPL_ALLREDUCE_REAL4, MPL_ALLREDUCE_INT, & MPL_ALLREDUCE_INT8, & MPL_ALLREDUCE_REAL8_SCALAR, MPL_ALLREDUCE_REAL4_SCALAR, & - MPL_ALLREDUCE_INT_SCALAR, MPL_ALLREDUCE_INT8_SCALAR + MPL_ALLREDUCE_INT_SCALAR, MPL_ALLREDUCE_INT8_SCALAR, & + MPL_ALLREDUCE_REAL4_2D, MPL_ALLREDUCE_REAL8_2D END INTERFACE PUBLIC MPL_ALLREDUCE @@ -459,6 +460,135 @@ SUBROUTINE MPL_ALLREDUCE_REAL8(PSENDBUF,CDOPER,LDREPROD, & END SUBROUTINE MPL_ALLREDUCE_REAL8 +SUBROUTINE MPL_ALLREDUCE_REAL8_2D(PSENDBUF,CDOPER,LDREPROD, & + & KCOMM,KERROR,CDSTRING) + + +#ifdef USE_8_BYTE_WORDS + USE MPI4TO8, ONLY : & + MPI_ALLREDUCE => MPI_ALLREDUCE8 +#endif + +REAL(KIND=JPRD),INTENT(INOUT) :: PSENDBUF(:,:) +CHARACTER(LEN=*),INTENT(IN) :: CDOPER +LOGICAL,INTENT(IN),OPTIONAL :: LDREPROD +INTEGER(KIND=JPIM),INTENT(IN),OPTIONAL :: KCOMM +INTEGER(KIND=JPIM),INTENT(OUT),OPTIONAL :: KERROR +CHARACTER(LEN=*),INTENT(IN),OPTIONAL :: CDSTRING +REAL(KIND=JPRD) :: ZRECVBUF(SIZE(PSENDBUF(:,1)),SIZE(PSENDBUF(1,:))) +INTEGER(KIND=JPIM) ITAG, ICOUNT +LOGICAL LLREPRODSUM +INTEGER(KIND=JPIM) :: ISENDCOUNT,ICOMM,IERROR,IOPER +INTEGER(KIND=JPIM) :: IP2,II,IHALF,JSTAGE,ISEND,IRECV,IMSENT +INTEGER(KIND=JPIM) :: ISREQ(MPL_NUMPROC) +INTEGER(KIND=JPIM) :: ITID +IERROR = 0 +ITID = OML_MY_THREAD() +LLREPRODSUM = .FALSE. + +IF(MPL_NUMPROC < 1) CALL MPL_MESSAGE( & + & CDMESSAGE='MPL_ALLREDUCE: MPL NOT INITIALISED ',LDABORT=LLABORT) + +IF(CDOPER(1:3) == 'MAX' .OR. CDOPER(1:3) == 'max' ) THEN + IOPER = MPI_MAX +ELSEIF(CDOPER(1:3) == 'MIN' .OR. CDOPER(1:3) == 'min' ) THEN + IOPER = MPI_MIN +ELSEIF(CDOPER(1:3) == 'SUM' .OR. CDOPER(1:3) == 'sum' ) THEN + IOPER = MPI_SUM + IF (PRESENT(LDREPROD)) THEN + LLREPRODSUM = LDREPROD + ELSE + CALL MPL_MESSAGE(IERROR,& + & 'MPL_ALLREDUCE: SUMMATION OPERATOR NOT REPRODUCIBLE IN REAL MODE',& + & CDSTRING,LDABORT=LLABORT) + ENDIF +ELSE + CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE: ERROR UNKNOWN OPERATOR',& + & CDSTRING,LDABORT=LLABORT) +ENDIF + +IF(PRESENT(KCOMM)) THEN + ICOMM=KCOMM +ELSE + ICOMM=MPL_COMM_OML(ITID) +ENDIF + +ISENDCOUNT = SIZE(PSENDBUF) +!#ifndef NAG +!IF (ISENDCOUNT > 0) THEN +! IF( (LOC(PSENDBUF(UBOUND(PSENDBUF,1)))-LOC(PSENDBUF(LBOUND(PSENDBUF,1)))) /= 8_JPIB*(ISENDCOUNT - 1) ) THEN +! CALL MPL_MESSAGE(CDMESSAGE='MPL_ALLREDUCE: BUFFER NOT CONTIGUOUS ',LDABORT=LLABORT) +! ENDIF +!ENDIF +!#endif + +!IF (LLREPRODSUM) THEN +!-- Near reproducible summation (independent of number of threads) + +! IP2=0 +! DO +! IP2=IP2+1 +! IF(2**IP2 >= MPL_NUMPROC) EXIT +! ENDDO + +! IMSENT=0 +! DO JSTAGE=IP2,1,-1 +! WRITE(0,*) 'STAGE ',JSTAGE +! ITAG = 2001+JSTAGE +! II = 2**JSTAGE +! IHALF = II/2 +! ISEND = MPL_RANK - IHALF +! IF(ISEND > 0 .AND. MPL_RANK <= II) THEN +! IMSENT=IMSENT+1 +! CALL MPL_SEND(PSENDBUF,KDEST=ISEND,KCOMM=ICOMM,KTAG=ITAG,KERROR=IERROR,& +! &KMP_TYPE=JP_NON_BLOCKING_STANDARD,KREQUEST=ISREQ(IMSENT),CDSTRING='MPLS_SEND') +! write(0,*) 'I SEND TO ',MPL_RANK,ISEND +! ENDIF +! IRECV=MPL_RANK + IHALF +! IF(IRECV <=MPL_NUMPROC .AND. MPL_RANK <= IHALF) THEN +! CALL MPL_RECV(ZRECVBUF,KSOURCE=IRECV,KCOMM=ICOMM,KTAG=ITAG,& +! &KERROR=IERROR,KOUNT=ICOUNT) +! write(0,*) 'I RECV FROM ',MPL_RANK,IRECV +! PSENDBUF(:) = PSENDBUF(:) + ZRECVBUF(:) +! ENDIF +! ENDDO +! IF(IMSENT > 0) THEN +! CALL MPL_WAIT(KREQUEST=ISREQ(1:IMSENT),CDSTRING='MPLS_SEND') +! ENDIF +! IF (MPL_RANK == 1) THEN +! ZRECVBUF(:) = PSENDBUF(:) +! ENDIF +! write(0,*) 'enter broadcast ' +! CALL MPL_BROADCAST(ZRECVBUF,KTAG=ITAG,KCOMM=ICOMM,KROOT=1,KERROR=IERROR) +! write(0,*) 'exit broadcast ' + +!ELSE + IF ( MPL_NUMPROC > 1 ) & + CALL MPI_ALLREDUCE(PSENDBUF,ZRECVBUF,ISENDCOUNT,INT(MPI_REAL8), & + & IOPER,ICOMM,IERROR) + + IF(LMPLSTATS) THEN + CALL MPL_SENDSTATS(ISENDCOUNT,INT(MPI_REAL8)) + CALL MPL_RECVSTATS(ISENDCOUNT,INT(MPI_REAL8)) + ENDIF + +!ENDIF + +IF(MPL_OUTPUT > 1 )THEN + WRITE(MPL_UNIT,'(A,5I8)') ' MPL_ALLREDUCE ',ISENDCOUNT,ICOMM,IOPER +ENDIF +IF(PRESENT(KERROR)) THEN + KERROR=IERROR +ELSE + IF(IERROR /= 0 ) CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT) +ENDIF + +IF ( MPL_NUMPROC > 1 ) & +PSENDBUF(:,:) = ZRECVBUF(:,:) + +END SUBROUTINE MPL_ALLREDUCE_REAL8_2D + + SUBROUTINE MPL_ALLREDUCE_REAL4(PSENDBUF,CDOPER,LDREPROD, & & KCOMM,KERROR,CDSTRING) @@ -568,6 +698,115 @@ SUBROUTINE MPL_ALLREDUCE_REAL4(PSENDBUF,CDOPER,LDREPROD, & END SUBROUTINE MPL_ALLREDUCE_REAL4 +SUBROUTINE MPL_ALLREDUCE_REAL4_2D(PSENDBUF,CDOPER,LDREPROD, & + & KCOMM,KERROR,CDSTRING) + + +#ifdef USE_8_BYTE_WORDS + USE MPI4TO8, ONLY : & + MPI_ALLREDUCE => MPI_ALLREDUCE8 +#endif + +REAL(KIND=JPRM),INTENT(INOUT) :: PSENDBUF(:,:) +CHARACTER(LEN=*),INTENT(IN) :: CDOPER +LOGICAL,INTENT(IN),OPTIONAL :: LDREPROD +INTEGER(KIND=JPIM),INTENT(IN),OPTIONAL :: KCOMM +INTEGER(KIND=JPIM),INTENT(OUT),OPTIONAL :: KERROR +CHARACTER(LEN=*),INTENT(IN),OPTIONAL :: CDSTRING +REAL(KIND=JPRM) :: ZRECVBUF(SIZE(PSENDBUF(:,1)),SIZE(PSENDBUF(1,:))) +INTEGER(KIND=JPIM) IPROC, ITAG, ICOUNT +LOGICAL LLREPRODSUM +INTEGER(KIND=JPIM) :: ISENDCOUNT,ICOMM,IERROR,IOPER +INTEGER(KIND=JPIM) :: ITID +IERROR = 0 +ITID = OML_MY_THREAD() +LLREPRODSUM = .FALSE. + +IF(MPL_NUMPROC < 1) CALL MPL_MESSAGE( & + & CDMESSAGE='MPL_ALLREDUCE: MPL NOT INITIALISED ',LDABORT=LLABORT) + +IF(CDOPER(1:3) == 'MAX' .OR. CDOPER(1:3) == 'max' ) THEN + IOPER = MPI_MAX +ELSEIF(CDOPER(1:3) == 'MIN' .OR. CDOPER(1:3) == 'min' ) THEN + IOPER = MPI_MIN +ELSEIF(CDOPER(1:3) == 'SUM' .OR. CDOPER(1:3) == 'sum' ) THEN + IOPER = MPI_SUM + IF (PRESENT(LDREPROD)) THEN + LLREPRODSUM = LDREPROD + ELSE + CALL MPL_MESSAGE(IERROR,& + & 'MPL_ALLREDUCE: SUMMATION OPERATOR NOT REPRODUCIBLE IN REAL MODE',& + & CDSTRING,LDABORT=LLABORT) + ENDIF +ELSE + CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE: ERROR UNKNOWN OPERATOR',& + & CDSTRING,LDABORT=LLABORT) +ENDIF + +IF(PRESENT(KCOMM)) THEN + ICOMM=KCOMM +ELSE + ICOMM=MPL_COMM_OML(ITID) +ENDIF + +ISENDCOUNT = SIZE(PSENDBUF) +!#ifndef NAG +!IF (ISENDCOUNT > 0) THEN +! IF( (LOC(PSENDBUF(UBOUND(PSENDBUF,1)))-LOC(PSENDBUF(LBOUND(PSENDBUF,1)))) /= 4_JPIB*(ISENDCOUNT - 1) ) THEN +! CALL MPL_MESSAGE(CDMESSAGE='MPL_ALLREDUCE: BUFFER NOT CONTIGUOUS ',LDABORT=LLABORT) +! ENDIF +!ENDIF +!#endif + +!IF (LLREPRODSUM) THEN +!-- Near reproducible summation +! ITAG = 2001 +! IF (MPL_RANK == 1) THEN +! DO IPROC=2,MPL_NUMPROC +! CALL MPL_RECV(ZRECVBUF,KSOURCE=IPROC,KCOMM=ICOMM,KTAG=ITAG,& +! &KERROR=IERROR,KOUNT=ICOUNT) +! IF (ICOUNT /= ISENDCOUNT) THEN +! WRITE(MPL_ERRUNIT,'(A,I10,A,I6,A,I10)')& +! & 'MPL_ALLREDUCE: RECEIVED UNEXPECTED NUMBER OF ELEMENTS ', & +! & ICOUNT,' FROM PROC ',IPROC,'. EXPECTED=',ISENDCOUNT +! CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT) +! ENDIF +! PSENDBUF(:) = PSENDBUF(:) + ZRECVBUF(:) +! ENDDO +! ZRECVBUF(:) = PSENDBUF(:) +! ELSE +! CALL MPL_SEND(PSENDBUF,KDEST=1,KCOMM=ICOMM,KTAG=ITAG,KERROR=IERROR,& +! &KMP_TYPE=JP_BLOCKING_STANDARD,CDSTRING='MPLS_SEND') +! ENDIF +! ITAG = ITAG + 1 +! CALL MPL_BROADCAST(ZRECVBUF,KTAG=ITAG,KCOMM=ICOMM,KROOT=1,KERROR=IERROR) +!ELSE +IF ( MPL_NUMPROC > 1 ) & +CALL MPI_ALLREDUCE(PSENDBUF,ZRECVBUF,ISENDCOUNT,INT(MPI_REAL4), & + & IOPER,ICOMM,IERROR) + +IF(LMPLSTATS) THEN + CALL MPL_SENDSTATS(ISENDCOUNT,INT(MPI_REAL4)) + CALL MPL_RECVSTATS(ISENDCOUNT,INT(MPI_REAL4)) +ENDIF + +!ENDIF + +IF(MPL_OUTPUT > 1 )THEN + WRITE(MPL_UNIT,'(A,5I8)') ' MPL_ALLREDUCE ',ISENDCOUNT,ICOMM,IOPER +ENDIF +IF(PRESENT(KERROR)) THEN + KERROR=IERROR +ELSE + IF(IERROR /= 0 ) CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT) +ENDIF + +IF ( MPL_NUMPROC > 1 ) & +PSENDBUF(:,:) = ZRECVBUF(:,:) + +END SUBROUTINE MPL_ALLREDUCE_REAL4_2D + + END MODULE MPL_ALLREDUCE_MOD