Skip to content

Commit

Permalink
fix: improve modern diag manager performance (#1634)
Browse files Browse the repository at this point in the history
  • Loading branch information
uramirez8707 authored Jan 16, 2025
1 parent c105a8d commit 15ec0c7
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 51 deletions.
13 changes: 13 additions & 0 deletions diag_manager/fms_diag_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ module fms_diag_object_mod
type(fmsDiagField_type), allocatable :: FMS_diag_fields(:) !< Array of diag fields
type(fmsDiagOutputBuffer_type), allocatable :: FMS_diag_output_buffers(:) !< array of output buffer objects
!! one for each variable in the diag_table.yaml
logical, private :: data_was_send !< True if send_data has been successfully called for at least one variable
!< diag_send_complete does nothing if it is .false.
integer, private :: registered_buffers = 0 !< number of registered buffers, per dimension
class(fmsDiagAxisContainer_type), allocatable :: diag_axis(:) !< Array of diag_axis
integer, private :: registered_variables !< Number of registered variables
Expand Down Expand Up @@ -144,6 +146,7 @@ subroutine fms_diag_object_init (this,diag_subset_output, time_init)
this%buffers_initialized =fms_diag_output_buffer_init(this%FMS_diag_output_buffers,SIZE(diag_yaml%get_diag_fields()))
this%registered_variables = 0
this%registered_axis = 0
this%data_was_send = .false.
this%initialized = .true.
#else
call mpp_error("fms_diag_object_init",&
Expand Down Expand Up @@ -657,6 +660,8 @@ subroutine fms_diag_accept_data (this, diag_field_id, field_data, mask, rmask, &
main_if: if (buffer_the_data) then
!> Only 1 thread allocates the output buffer and sets set_math_needs_to_be_done
!$omp critical
!< Let diag_send_complete know that there is new data to process
if (.not. this%data_was_send) this%data_was_send = .true.

!< These set_* calls need to be done inside an omp_critical to avoid any race conditions
!! and allocation issues
Expand Down Expand Up @@ -686,6 +691,9 @@ subroutine fms_diag_accept_data (this, diag_field_id, field_data, mask, rmask, &
is, js, ks, ie, je, ke)
else

!< Let diag_send_complete know that there is new data to process
if (.not. this%data_was_send) this%data_was_send = .true.

!< At this point we are either no longer in an OpenMP region or running with 1 thread,
!! so it is safe to have these set_* calls
if(has_halos) call this%FMS_diag_fields(diag_field_id)%set_halo_present()
Expand Down Expand Up @@ -783,8 +791,13 @@ subroutine fms_diag_send_complete(this, time_step)
#ifndef use_yaml
CALL MPP_ERROR(FATAL,"You can not use the modern diag manager without compiling with -Duse_yaml")
#else
!< Go away if there is no new data
if (.not. this%data_was_send) return

call this%do_buffer_math()
call this%fms_diag_do_io()

this%data_was_send = .false.
#endif

end subroutine fms_diag_send_complete
Expand Down
21 changes: 21 additions & 0 deletions diag_manager/fms_diag_reduction_methods.F90
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@ module fms_diag_reduction_methods_mod
module procedure sum_update_done_r4, sum_update_done_r8
end interface

!> @brief Updates the buffer for any reductions that involve summation
!! (ie. time_sum, avg, rms, pow)
!! In this case the mask is present and it does not vary over time
!! (the time-varying case is handled by sum_mask_variant)
interface sum_mask
  module procedure sum_mask_r4, sum_mask_r8
end interface sum_mask

!> @brief Generic entry point for summation-type reductions (ie. time_sum, avg, rms, pow)
!! when a mask is present and that mask varies over time.
!! Resolves to the r4 or r8 specific procedure.
interface sum_mask_variant
  module procedure sum_mask_variant_r4
  module procedure sum_mask_variant_r8
end interface sum_mask_variant

!> @brief Generic entry point for summation-type reductions (ie. time_sum, avg, rms, pow)
!! when no mask is present.
!! Resolves to the r4 or r8 specific procedure.
interface sum_no_mask
  module procedure sum_no_mask_r4
  module procedure sum_no_mask_r8
end interface sum_no_mask

contains

!> @brief Checks improper combinations of is, ie, js, and je.
Expand Down
234 changes: 185 additions & 49 deletions diag_manager/include/fms_diag_reduction_methods.inc
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,7 @@ subroutine DO_TIME_SUM_UPDATE_(data_out, weight_sum, data_in, mask, is_masked, m
integer ,optional, intent(in) :: pow !< Used for pow(er) reduction,
!! calculates field_data^pow before adding to buffer

integer :: is_in, ie_in, js_in, je_in, ks_in, ke_in !< Starting and ending indices of each dimention for
!! the input buffer
integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimention for
!! the output buffer
integer :: i, j, k, l !< For looping
real(FMS_TRM_KIND_) :: weight_scale !< local copy of optional weight
integer :: pow_loc !> local copy of optional pow value (set if using pow reduction)
integer, parameter :: kindl = FMS_TRM_KIND_ !< real kind size as set by macro
integer :: diurnal !< diurnal index to indicate which daily section is updated
!! will be 1 unless using a diurnal reduction
Expand All @@ -252,18 +246,49 @@ subroutine DO_TIME_SUM_UPDATE_(data_out, weight_sum, data_in, mask, is_masked, m
weight_scale = 1.0_kindl
endif

if(present(pow)) then
pow_loc = pow
else
pow_loc = 1.0_kindl
endif

if(diurnal_section .lt. 0) then
diurnal = 1
else
diurnal = diurnal_section
endif

if (is_masked) then
if (mask_variant) then
! Mask changes over time so the weight is an array
call sum_mask_variant(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, weight_scale, pow)
else
call sum_mask(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, &
missing_value, weight_scale, pow)
endif
else
call sum_no_mask(data_out, data_in, weight_sum, bounds_in, bounds_out, diurnal, weight_scale, pow)
endif
end subroutine DO_TIME_SUM_UPDATE_

subroutine SUM_MASK_(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, missing_value, &
weight_scale, pow)
real(FMS_TRM_KIND_), intent(inout) :: data_out(:,:,:,:,:) !< output data
real(FMS_TRM_KIND_), intent(in) :: data_in(:,:,:,:) !< data to update the buffer with
real(r8_kind), intent(inout) :: weight_sum(:,:,:,:) !< Sum of weights from the output buffer object
type(fmsDiagIbounds_type), intent(in) :: bounds_in !< indices indicating the correct portion
!! of the input buffer
type(fmsDiagIbounds_type), intent(in) :: bounds_out !< indices indicating the correct portion
!! of the output buffer
logical, intent(in) :: mask(:,:,:,:) !< mask
integer, intent(in) :: diurnal !< diurnal index to indicate which daily section is
!! updated will be 1 unless using a diurnal reduction
real(FMS_TRM_KIND_), intent(in) :: missing_value !< Missing_value for data points that are masked
real(FMS_TRM_KIND_), intent(in) :: weight_scale !< weight scale to use
integer ,optional, intent(in) :: pow !< Used for pow(er) reduction,
!! calculates field_data^pow before adding to buffer

integer :: is_in, ie_in, js_in, je_in, ks_in, ke_in !< Starting and ending indices of each dimention for
!! the input buffer
integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimention for
!! the output buffer
integer :: pow_loc !> local copy of optional pow value (set if using pow reduction)
integer :: i, j, k, l !< For looping

is_out = bounds_out%get_imin()
ie_out = bounds_out%get_imax()
js_out = bounds_out%get_jmin()
Expand All @@ -278,56 +303,167 @@ subroutine DO_TIME_SUM_UPDATE_(data_out, weight_sum, data_in, mask, is_masked, m
ks_in = bounds_in%get_kmin()
ke_in = bounds_in%get_kmax()

!> Separated these loops for performance. If is_masked = .false. (i.e. "mask" and "rmask" were never passed in)
!! then mask will always be .true., so the if (mask) check is redundant.
! TODO check if performance gain by not doing weight and pow if not needed
if (is_masked) then
if (mask_variant) then
! Mask changes over time so the weight is an array
do k = 0, ke_out - ks_out
do j = 0, je_out - js_out
do i = 0, ie_out - is_out
where (mask(is_in + i, js_in + j, ks_in + k, :))
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
+ (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow_loc
!Increase the weight sum for the grid point that was not masked
weight_sum(is_out + i, js_out + j, ks_out + k, :) = &
weight_sum(is_out + i, js_out + j, ks_out + k, :) + weight_scale
endwhere
enddo
weight_sum = weight_sum + weight_scale
if (present(pow)) then
do k = 0, ke_out - ks_out
do j = 0, je_out - js_out
do i = 0, ie_out - is_out
where (mask(is_in + i, js_in + j, ks_in + k, :))
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
+ (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow
elsewhere
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = missing_value
endwhere
enddo
enddo
else
weight_sum = weight_sum + weight_scale
do k = 0, ke_out - ks_out
do j = 0, je_out - js_out
do i = 0, ie_out - is_out
where (mask(is_in + i, js_in + j, ks_in + k, :))
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
+ (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow_loc
elsewhere
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = missing_value
endwhere
enddo
enddo
else
do k = 0, ke_out - ks_out
do j = 0, je_out - js_out
do i = 0, ie_out - is_out
where (mask(is_in + i, js_in + j, ks_in + k, :))
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
+ (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale)
elsewhere
data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = missing_value
endwhere
enddo
enddo
endif
enddo
endif
end subroutine SUM_MASK_

!> @brief Adds the input data to a summation buffer (time_sum, avg, rms, pow reductions) for the
!! case where a mask is present and may differ from one call to the next.
!! Only points whose mask value is .true. are accumulated, and weight_sum is increased by
!! weight_scale at exactly those points; points where the mask is .false. are left untouched.
subroutine SUM_MASK_VARIANT_(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, weight_scale, pow)
  real(FMS_TRM_KIND_),       intent(inout) :: data_out(:,:,:,:,:) !< output data
  real(FMS_TRM_KIND_),       intent(in)    :: data_in(:,:,:,:)    !< data to update the buffer with
  real(r8_kind),             intent(inout) :: weight_sum(:,:,:,:) !< Sum of weights from the output buffer object
  type(fmsDiagIbounds_type), intent(in)    :: bounds_in           !< indices indicating the correct portion
                                                                  !! of the input buffer
  type(fmsDiagIbounds_type), intent(in)    :: bounds_out          !< indices indicating the correct portion
                                                                  !! of the output buffer
  logical,                   intent(in)    :: mask(:,:,:,:)       !< mask
  integer,                   intent(in)    :: diurnal             !< diurnal index to indicate which daily section is
                                                                  !! updated will be 1 unless using a diurnal reduction
  real(FMS_TRM_KIND_),       intent(in)    :: weight_scale        !< weight scale to use
  integer, optional,         intent(in)    :: pow                 !< Used for pow(er) reduction,
                                                                  !! calculates field_data^pow before adding to buffer

  integer :: is_in, js_in, ks_in                            !< Starting indices of each dimension for
                                                            !! the input buffer
  integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimension for
                                                            !! the output buffer
  integer :: i, j, k                                        !< For looping

  is_out = bounds_out%get_imin()
  ie_out = bounds_out%get_imax()
  js_out = bounds_out%get_jmin()
  je_out = bounds_out%get_jmax()
  ks_out = bounds_out%get_kmin()
  ke_out = bounds_out%get_kmax()

  ! Only the starting indices of the input window are needed: the loops run over the extent of
  ! the output window and offset into the input buffer from its starting indices.
  is_in = bounds_in%get_imin()
  js_in = bounds_in%get_jmin()
  ks_in = bounds_in%get_kmin()

  ! The pow / no-pow cases are kept as separate loop nests so that the exponentiation is
  ! only performed when a pow reduction was actually requested.
  if (present(pow)) then
    do k = 0, ke_out - ks_out
      do j = 0, je_out - js_out
        do i = 0, ie_out - is_out
          where (mask(is_in + i, js_in + j, ks_in + k, :))
            data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
              data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
              + (data_in(is_in + i, js_in + j, ks_in + k, :) * weight_scale) ** pow

            !Increase the weight sum for the grid point that was not masked
            weight_sum(is_out + i, js_out + j, ks_out + k, :) = &
              weight_sum(is_out + i, js_out + j, ks_out + k, :) + weight_scale
          endwhere
        enddo
      enddo
    enddo
  else
    ! weight_sum must NOT be incremented as a whole array here: the weight is tracked per
    ! grid point inside the where block below, so a blanket increment would count every
    ! unmasked point twice (and give masked points a weight they never contributed).
    do k = 0, ke_out - ks_out
      do j = 0, je_out - js_out
        do i = 0, ie_out - is_out
          where (mask(is_in + i, js_in + j, ks_in + k, :))
            data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
              data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
              + (data_in(is_in + i, js_in + j, ks_in + k, :) * weight_scale)

            !Increase the weight sum for the grid point that was not masked
            weight_sum(is_out + i, js_out + j, ks_out + k, :) = &
              weight_sum(is_out + i, js_out + j, ks_out + k, :) + weight_scale
          endwhere
        enddo
      enddo
    enddo
  endif
end subroutine SUM_MASK_VARIANT_

!> @brief Adds the input data to a summation buffer (time_sum, avg, rms, pow reductions) for the
!! case where no mask is present.
!! Every point in the output window is accumulated and the whole weight_sum array is
!! increased by weight_scale once per call.
subroutine SUM_NO_MASK_(data_out, data_in, weight_sum, bounds_in, bounds_out, diurnal, weight_scale, pow)
  real(FMS_TRM_KIND_),       intent(inout) :: data_out(:,:,:,:,:) !< output data
  real(FMS_TRM_KIND_),       intent(in)    :: data_in(:,:,:,:)    !< data to update the buffer with
  real(r8_kind),             intent(inout) :: weight_sum(:,:,:,:) !< Sum of weights from the output buffer object
  type(fmsDiagIbounds_type), intent(in)    :: bounds_in           !< indices indicating the correct portion
                                                                  !! of the input buffer
  type(fmsDiagIbounds_type), intent(in)    :: bounds_out          !< indices indicating the correct portion
                                                                  !! of the output buffer
  integer,                   intent(in)    :: diurnal             !< diurnal index to indicate which daily section is
                                                                  !! updated will be 1 unless using a diurnal reduction
  real(FMS_TRM_KIND_),       intent(in)    :: weight_scale        !< weight scale to use
  integer, optional,         intent(in)    :: pow                 !< Used for pow(er) reduction,
                                                                  !! calculates field_data^pow before adding to buffer

  integer :: is_in, js_in, ks_in                            !< Starting indices of each dimension for
                                                            !! the input buffer
  integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimension for
                                                            !! the output buffer
  integer :: i, j, k                                        !< For looping

  is_out = bounds_out%get_imin()
  ie_out = bounds_out%get_imax()
  js_out = bounds_out%get_jmin()
  je_out = bounds_out%get_jmax()
  ks_out = bounds_out%get_kmin()
  ke_out = bounds_out%get_kmax()

  ! Only the starting indices of the input window are needed: the loops run over the extent of
  ! the output window and offset into the input buffer from its starting indices.
  is_in = bounds_in%get_imin()
  js_in = bounds_in%get_jmin()
  ks_in = bounds_in%get_kmin()

  ! Without a mask every point receives the same weight, so the whole array is updated at once
  weight_sum = weight_sum + weight_scale

  ! The 4th dimension is handled as an array section, so no explicit loop over it is needed.
  ! The pow / no-pow cases are kept as separate loop nests so that the exponentiation is
  ! only performed when a pow reduction was actually requested.
  if (present(pow)) then
    do k = 0, ke_out - ks_out
      do j = 0, je_out - js_out
        do i = 0, ie_out - is_out
          data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
            data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
            + (data_in(is_in + i, js_in + j, ks_in + k, :) * weight_scale) ** pow
        enddo
      enddo
    enddo
  else
    do k = 0, ke_out - ks_out
      do j = 0, je_out - js_out
        do i = 0, ie_out - is_out
          data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = &
            data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) &
            + (data_in(is_in + i, js_in + j, ks_in + k, :) * weight_scale)
        enddo
      enddo
    enddo
  endif
end subroutine SUM_NO_MASK_

!> To be called with diag_send_complete, finishes reductions
!! Just divides the buffer by the counter array(which is just the sum of the weights used in the buffer's reduction)
Expand Down
11 changes: 10 additions & 1 deletion diag_manager/include/fms_diag_reduction_methods_r4.fh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,16 @@
#undef SUM_UPDATE_DONE_
#define SUM_UPDATE_DONE_ sum_update_done_r4

#undef SUM_MASK_
#define SUM_MASK_ sum_mask_r4

#undef SUM_NO_MASK_
#define SUM_NO_MASK_ sum_no_mask_r4

#undef SUM_MASK_VARIANT_
#define SUM_MASK_VARIANT_ sum_mask_variant_r4

#include "fms_diag_reduction_methods.inc"

!> @}
! close documentation grouping
! close documentation grouping
11 changes: 10 additions & 1 deletion diag_manager/include/fms_diag_reduction_methods_r8.fh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,16 @@
#undef SUM_UPDATE_DONE_
#define SUM_UPDATE_DONE_ sum_update_done_r8

#undef SUM_MASK_
#define SUM_MASK_ sum_mask_r8

#undef SUM_NO_MASK_
#define SUM_NO_MASK_ sum_no_mask_r8

#undef SUM_MASK_VARIANT_
#define SUM_MASK_VARIANT_ sum_mask_variant_r8

#include "fms_diag_reduction_methods.inc"

!> @}
! close documentation grouping
! close documentation grouping

0 comments on commit 15ec0c7

Please sign in to comment.