Skip to content

Commit

Permalink
Add a 2D MPL_ALLREDUCE interface (#19)
Browse files Browse the repository at this point in the history
Co-authored-by: Iria Ayan-Miguez <[email protected]>
  • Loading branch information
towil1 and Iria Ayan-Miguez authored Apr 16, 2024
1 parent a786659 commit d0b0384
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 2 deletions.
67 changes: 66 additions & 1 deletion src/fiat/mpl/internal/mpi4to8_m.F90
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ MODULE MPI4TO8_M

INTERFACE MPI_ALLREDUCE8
MODULE PROCEDURE MPI_ALLREDUCE8_R4, MPI_ALLREDUCE8_R8, &
MPI_ALLREDUCE8_I4
MPI_ALLREDUCE8_I4, MPI_ALLREDUCE8_R4_2D, &
MPI_ALLREDUCE8_R8_2D
END INTERFACE MPI_ALLREDUCE8

INTERFACE MPI_ALLTOALLV8
Expand Down Expand Up @@ -179,6 +180,43 @@ SUBROUTINE MPI_ALLREDUCE8_R4(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &

END SUBROUTINE MPI_ALLREDUCE8_R4

SUBROUTINE MPI_ALLREDUCE8_R4_2D(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
COMM, IERROR)

REAL(KIND=JPRM), DIMENSION(:,:), INTENT(IN) :: &
SENDDATA(:,:)
INTEGER(KIND=JPIM), INTENT(IN) :: &
COUNT, DATATYPE, OP, COMM
REAL(KIND=JPRM), DIMENSION(:,:), INTENT(OUT) :: &
RECVDATA(:.:)
INTEGER(KIND=JPIM), INTENT(OUT) :: &
IERROR

REAL(KIND=8), DIMENSION(:,:), ALLOCATABLE :: &
SENDDATA8, RECVDATA8
INTEGER(KIND=8) :: &
COUNT8, DATATYPE8, OP8, COMM8, IERROR8

ALLOCATE(SENDDATA8(SIZE(SENDDATA)))
ALLOCATE(RECVDATA8(SIZE(RECVDATA)))

SENDDATA8 = SENDDATA
COUNT8 = COUNT
DATATYPE8 = DATATYPE
OP8 = OP
COMM8 = COMM

CALL MPI_ALLREDUCE(SENDDATA8, RECVDATA8, COUNT8, DATATYPE8, OP8, COMM8, IERROR8)

RECVDATA = RECVDATA8
IERROR = IERROR8

DEALLOCATE(SENDDATA8)
DEALLOCATE(RECVDATA8)

END SUBROUTINE MPI_ALLREDUCE8_R4_2D


! ---------------------------------------------------------
SUBROUTINE MPI_ALLREDUCE8_R8(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
COMM, IERROR)
Expand Down Expand Up @@ -206,6 +244,33 @@ SUBROUTINE MPI_ALLREDUCE8_R8(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &

END SUBROUTINE MPI_ALLREDUCE8_R8

SUBROUTINE MPI_ALLREDUCE8_R8_2D(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
COMM, IERROR)

REAL(KIND=JPRD), DIMENSION(:,:), INTENT(IN) :: &
SENDDATA(:,:)
INTEGER(KIND=JPIM), INTENT(IN) :: &
COUNT, DATATYPE, OP, COMM
REAL(KIND=JPRD), DIMENSION(:,:), INTENT(OUT) :: &
RECVDATA(:,:)
INTEGER(KIND=JPIM), INTENT(OUT) :: &
IERROR

INTEGER(KIND=8) :: &
COUNT8, DATATYPE8, OP8, COMM8, IERROR8

COUNT8 = COUNT
DATATYPE8 = DATATYPE
OP8 = OP
COMM8 = COMM

CALL MPI_ALLREDUCE(SENDDATA, RECVDATA, COUNT8, DATATYPE8, OP8, COMM8, IERROR8)

IERROR = IERROR8

END SUBROUTINE MPI_ALLREDUCE8_R8


! ---------------------------------------------------------
SUBROUTINE MPI_ALLREDUCE8_I4(SENDDATA, RECVDATA, COUNT, DATATYPE, OP, &
COMM, IERROR)
Expand Down
241 changes: 240 additions & 1 deletion src/fiat/mpl/internal/mpl_allreduce_mod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ MODULE MPL_ALLREDUCE_MOD
MODULE PROCEDURE MPL_ALLREDUCE_REAL8, MPL_ALLREDUCE_REAL4, MPL_ALLREDUCE_INT, &
MPL_ALLREDUCE_INT8, &
MPL_ALLREDUCE_REAL8_SCALAR, MPL_ALLREDUCE_REAL4_SCALAR, &
MPL_ALLREDUCE_INT_SCALAR, MPL_ALLREDUCE_INT8_SCALAR
MPL_ALLREDUCE_INT_SCALAR, MPL_ALLREDUCE_INT8_SCALAR, &
MPL_ALLREDUCE_REAL4_2D, MPL_ALLREDUCE_REAL8_2D
END INTERFACE

PUBLIC MPL_ALLREDUCE
Expand Down Expand Up @@ -459,6 +460,135 @@ SUBROUTINE MPL_ALLREDUCE_REAL8(PSENDBUF,CDOPER,LDREPROD, &

END SUBROUTINE MPL_ALLREDUCE_REAL8

SUBROUTINE MPL_ALLREDUCE_REAL8_2D(PSENDBUF,CDOPER,LDREPROD, &
& KCOMM,KERROR,CDSTRING)


#ifdef USE_8_BYTE_WORDS
USE MPI4TO8, ONLY : &
MPI_ALLREDUCE => MPI_ALLREDUCE8
#endif

REAL(KIND=JPRD),INTENT(INOUT) :: PSENDBUF(:,:)
CHARACTER(LEN=*),INTENT(IN) :: CDOPER
LOGICAL,INTENT(IN),OPTIONAL :: LDREPROD
INTEGER(KIND=JPIM),INTENT(IN),OPTIONAL :: KCOMM
INTEGER(KIND=JPIM),INTENT(OUT),OPTIONAL :: KERROR
CHARACTER(LEN=*),INTENT(IN),OPTIONAL :: CDSTRING
REAL(KIND=JPRD) :: ZRECVBUF(SIZE(PSENDBUF(:,1)),SIZE(PSENDBUF(1,:)))
INTEGER(KIND=JPIM) ITAG, ICOUNT
LOGICAL LLREPRODSUM
INTEGER(KIND=JPIM) :: ISENDCOUNT,ICOMM,IERROR,IOPER
INTEGER(KIND=JPIM) :: IP2,II,IHALF,JSTAGE,ISEND,IRECV,IMSENT
INTEGER(KIND=JPIM) :: ISREQ(MPL_NUMPROC)
INTEGER(KIND=JPIM) :: ITID
IERROR = 0
ITID = OML_MY_THREAD()
LLREPRODSUM = .FALSE.

IF(MPL_NUMPROC < 1) CALL MPL_MESSAGE( &
& CDMESSAGE='MPL_ALLREDUCE: MPL NOT INITIALISED ',LDABORT=LLABORT)

IF(CDOPER(1:3) == 'MAX' .OR. CDOPER(1:3) == 'max' ) THEN
IOPER = MPI_MAX
ELSEIF(CDOPER(1:3) == 'MIN' .OR. CDOPER(1:3) == 'min' ) THEN
IOPER = MPI_MIN
ELSEIF(CDOPER(1:3) == 'SUM' .OR. CDOPER(1:3) == 'sum' ) THEN
IOPER = MPI_SUM
IF (PRESENT(LDREPROD)) THEN
LLREPRODSUM = LDREPROD
ELSE
CALL MPL_MESSAGE(IERROR,&
& 'MPL_ALLREDUCE: SUMMATION OPERATOR NOT REPRODUCIBLE IN REAL MODE',&
& CDSTRING,LDABORT=LLABORT)
ENDIF
ELSE
CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE: ERROR UNKNOWN OPERATOR',&
& CDSTRING,LDABORT=LLABORT)
ENDIF

IF(PRESENT(KCOMM)) THEN
ICOMM=KCOMM
ELSE
ICOMM=MPL_COMM_OML(ITID)
ENDIF

ISENDCOUNT = SIZE(PSENDBUF)
!#ifndef NAG
!IF (ISENDCOUNT > 0) THEN
! IF( (LOC(PSENDBUF(UBOUND(PSENDBUF,1)))-LOC(PSENDBUF(LBOUND(PSENDBUF,1)))) /= 8_JPIB*(ISENDCOUNT - 1) ) THEN
! CALL MPL_MESSAGE(CDMESSAGE='MPL_ALLREDUCE: BUFFER NOT CONTIGUOUS ',LDABORT=LLABORT)
! ENDIF
!ENDIF
!#endif

!IF (LLREPRODSUM) THEN
!-- Near reproducible summation (independent of number of threads)

! IP2=0
! DO
! IP2=IP2+1
! IF(2**IP2 >= MPL_NUMPROC) EXIT
! ENDDO

! IMSENT=0
! DO JSTAGE=IP2,1,-1
! WRITE(0,*) 'STAGE ',JSTAGE
! ITAG = 2001+JSTAGE
! II = 2**JSTAGE
! IHALF = II/2
! ISEND = MPL_RANK - IHALF
! IF(ISEND > 0 .AND. MPL_RANK <= II) THEN
! IMSENT=IMSENT+1
! CALL MPL_SEND(PSENDBUF,KDEST=ISEND,KCOMM=ICOMM,KTAG=ITAG,KERROR=IERROR,&
! &KMP_TYPE=JP_NON_BLOCKING_STANDARD,KREQUEST=ISREQ(IMSENT),CDSTRING='MPLS_SEND')
! write(0,*) 'I SEND TO ',MPL_RANK,ISEND
! ENDIF
! IRECV=MPL_RANK + IHALF
! IF(IRECV <=MPL_NUMPROC .AND. MPL_RANK <= IHALF) THEN
! CALL MPL_RECV(ZRECVBUF,KSOURCE=IRECV,KCOMM=ICOMM,KTAG=ITAG,&
! &KERROR=IERROR,KOUNT=ICOUNT)
! write(0,*) 'I RECV FROM ',MPL_RANK,IRECV
! PSENDBUF(:) = PSENDBUF(:) + ZRECVBUF(:)
! ENDIF
! ENDDO
! IF(IMSENT > 0) THEN
! CALL MPL_WAIT(KREQUEST=ISREQ(1:IMSENT),CDSTRING='MPLS_SEND')
! ENDIF
! IF (MPL_RANK == 1) THEN
! ZRECVBUF(:) = PSENDBUF(:)
! ENDIF
! write(0,*) 'enter broadcast '
! CALL MPL_BROADCAST(ZRECVBUF,KTAG=ITAG,KCOMM=ICOMM,KROOT=1,KERROR=IERROR)
! write(0,*) 'exit broadcast '

!ELSE
IF ( MPL_NUMPROC > 1 ) &
CALL MPI_ALLREDUCE(PSENDBUF,ZRECVBUF,ISENDCOUNT,INT(MPI_REAL8), &
& IOPER,ICOMM,IERROR)

IF(LMPLSTATS) THEN
CALL MPL_SENDSTATS(ISENDCOUNT,INT(MPI_REAL8))
CALL MPL_RECVSTATS(ISENDCOUNT,INT(MPI_REAL8))
ENDIF

!ENDIF

IF(MPL_OUTPUT > 1 )THEN
WRITE(MPL_UNIT,'(A,5I8)') ' MPL_ALLREDUCE ',ISENDCOUNT,ICOMM,IOPER
ENDIF
IF(PRESENT(KERROR)) THEN
KERROR=IERROR
ELSE
IF(IERROR /= 0 ) CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT)
ENDIF

IF ( MPL_NUMPROC > 1 ) &
PSENDBUF(:,:) = ZRECVBUF(:,:)

END SUBROUTINE MPL_ALLREDUCE_REAL8_2D



SUBROUTINE MPL_ALLREDUCE_REAL4(PSENDBUF,CDOPER,LDREPROD, &
& KCOMM,KERROR,CDSTRING)
Expand Down Expand Up @@ -568,6 +698,115 @@ SUBROUTINE MPL_ALLREDUCE_REAL4(PSENDBUF,CDOPER,LDREPROD, &

END SUBROUTINE MPL_ALLREDUCE_REAL4

SUBROUTINE MPL_ALLREDUCE_REAL4_2D(PSENDBUF,CDOPER,LDREPROD, &
& KCOMM,KERROR,CDSTRING)


#ifdef USE_8_BYTE_WORDS
USE MPI4TO8, ONLY : &
MPI_ALLREDUCE => MPI_ALLREDUCE8
#endif

REAL(KIND=JPRM),INTENT(INOUT) :: PSENDBUF(:,:)
CHARACTER(LEN=*),INTENT(IN) :: CDOPER
LOGICAL,INTENT(IN),OPTIONAL :: LDREPROD
INTEGER(KIND=JPIM),INTENT(IN),OPTIONAL :: KCOMM
INTEGER(KIND=JPIM),INTENT(OUT),OPTIONAL :: KERROR
CHARACTER(LEN=*),INTENT(IN),OPTIONAL :: CDSTRING
REAL(KIND=JPRM) :: ZRECVBUF(SIZE(PSENDBUF(:,1)),SIZE(PSENDBUF(1,:)))
INTEGER(KIND=JPIM) IPROC, ITAG, ICOUNT
LOGICAL LLREPRODSUM
INTEGER(KIND=JPIM) :: ISENDCOUNT,ICOMM,IERROR,IOPER
INTEGER(KIND=JPIM) :: ITID
IERROR = 0
ITID = OML_MY_THREAD()
LLREPRODSUM = .FALSE.

IF(MPL_NUMPROC < 1) CALL MPL_MESSAGE( &
& CDMESSAGE='MPL_ALLREDUCE: MPL NOT INITIALISED ',LDABORT=LLABORT)

IF(CDOPER(1:3) == 'MAX' .OR. CDOPER(1:3) == 'max' ) THEN
IOPER = MPI_MAX
ELSEIF(CDOPER(1:3) == 'MIN' .OR. CDOPER(1:3) == 'min' ) THEN
IOPER = MPI_MIN
ELSEIF(CDOPER(1:3) == 'SUM' .OR. CDOPER(1:3) == 'sum' ) THEN
IOPER = MPI_SUM
IF (PRESENT(LDREPROD)) THEN
LLREPRODSUM = LDREPROD
ELSE
CALL MPL_MESSAGE(IERROR,&
& 'MPL_ALLREDUCE: SUMMATION OPERATOR NOT REPRODUCIBLE IN REAL MODE',&
& CDSTRING,LDABORT=LLABORT)
ENDIF
ELSE
CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE: ERROR UNKNOWN OPERATOR',&
& CDSTRING,LDABORT=LLABORT)
ENDIF

IF(PRESENT(KCOMM)) THEN
ICOMM=KCOMM
ELSE
ICOMM=MPL_COMM_OML(ITID)
ENDIF

ISENDCOUNT = SIZE(PSENDBUF)
!#ifndef NAG
!IF (ISENDCOUNT > 0) THEN
! IF( (LOC(PSENDBUF(UBOUND(PSENDBUF,1)))-LOC(PSENDBUF(LBOUND(PSENDBUF,1)))) /= 4_JPIB*(ISENDCOUNT - 1) ) THEN
! CALL MPL_MESSAGE(CDMESSAGE='MPL_ALLREDUCE: BUFFER NOT CONTIGUOUS ',LDABORT=LLABORT)
! ENDIF
!ENDIF
!#endif

!IF (LLREPRODSUM) THEN
!-- Near reproducible summation
! ITAG = 2001
! IF (MPL_RANK == 1) THEN
! DO IPROC=2,MPL_NUMPROC
! CALL MPL_RECV(ZRECVBUF,KSOURCE=IPROC,KCOMM=ICOMM,KTAG=ITAG,&
! &KERROR=IERROR,KOUNT=ICOUNT)
! IF (ICOUNT /= ISENDCOUNT) THEN
! WRITE(MPL_ERRUNIT,'(A,I10,A,I6,A,I10)')&
! & 'MPL_ALLREDUCE: RECEIVED UNEXPECTED NUMBER OF ELEMENTS ', &
! & ICOUNT,' FROM PROC ',IPROC,'. EXPECTED=',ISENDCOUNT
! CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT)
! ENDIF
! PSENDBUF(:) = PSENDBUF(:) + ZRECVBUF(:)
! ENDDO
! ZRECVBUF(:) = PSENDBUF(:)
! ELSE
! CALL MPL_SEND(PSENDBUF,KDEST=1,KCOMM=ICOMM,KTAG=ITAG,KERROR=IERROR,&
! &KMP_TYPE=JP_BLOCKING_STANDARD,CDSTRING='MPLS_SEND')
! ENDIF
! ITAG = ITAG + 1
! CALL MPL_BROADCAST(ZRECVBUF,KTAG=ITAG,KCOMM=ICOMM,KROOT=1,KERROR=IERROR)
!ELSE
IF ( MPL_NUMPROC > 1 ) &
CALL MPI_ALLREDUCE(PSENDBUF,ZRECVBUF,ISENDCOUNT,INT(MPI_REAL4), &
& IOPER,ICOMM,IERROR)

IF(LMPLSTATS) THEN
CALL MPL_SENDSTATS(ISENDCOUNT,INT(MPI_REAL4))
CALL MPL_RECVSTATS(ISENDCOUNT,INT(MPI_REAL4))
ENDIF

!ENDIF

IF(MPL_OUTPUT > 1 )THEN
WRITE(MPL_UNIT,'(A,5I8)') ' MPL_ALLREDUCE ',ISENDCOUNT,ICOMM,IOPER
ENDIF
IF(PRESENT(KERROR)) THEN
KERROR=IERROR
ELSE
IF(IERROR /= 0 ) CALL MPL_MESSAGE(IERROR,'MPL_ALLREDUCE',CDSTRING,LDABORT=LLABORT)
ENDIF

IF ( MPL_NUMPROC > 1 ) &
PSENDBUF(:,:) = ZRECVBUF(:,:)

END SUBROUTINE MPL_ALLREDUCE_REAL4_2D


END MODULE MPL_ALLREDUCE_MOD


0 comments on commit d0b0384

Please sign in to comment.