Skip to content

Commit

Permalink
simplify reduction done and add todos
Browse files Browse the repository at this point in the history
  • Loading branch information
rem1776 authored and rem1776 committed Nov 29, 2023
1 parent 3a8a4fd commit 3a9997b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 89 deletions.
41 changes: 9 additions & 32 deletions diag_manager/fms_diag_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,8 @@ end subroutine fms_diag_send_complete

!> @brief Loops through all the files, open the file, writes out axis and
!! variable metadata and data when necessary.
!! TODO: passing in the saved mask from the field obj to diag_reduction_done_wrapper
!! for performance
subroutine fms_diag_do_io(this, is_end_of_run)
class(fmsDiagObject_type), target, intent(inout) :: this !< The diag object
logical, optional, intent(in) :: is_end_of_run !< If .true. this is the end of the run,
Expand All @@ -722,8 +724,7 @@ subroutine fms_diag_do_io(this, is_end_of_run)
integer :: ibuff, mask_zbounds(2), mask_shape(4)
logical :: file_is_opened_this_time_step !< True if the file was opened in this time_step
!! If true the metadata will need to be written
logical :: force_write, is_writing, subregional, has_halo
logical, allocatable :: mask_adj(:,:,:,:), mask_tmp(:,:,:,:) !< copy of field mask and ajusted mask
logical :: force_write, is_writing, has_mask
logical, parameter :: DEBUG_REDUCT = .false.
class(*), allocatable :: missing_val
real(r8_kind) :: mval
Expand Down Expand Up @@ -766,37 +767,13 @@ subroutine fms_diag_do_io(this, is_end_of_run)
if( field_yaml%has_var_reduction()) then
if( field_yaml%get_var_reduction() .ge. time_average) then
if(DEBUG_REDUCT) call mpp_error(NOTE, "fms_diag_do_io:: finishing reduction for "//diag_field%get_longname())
subregional = diag_file%FMS_diag_file%has_file_sub_region()
has_halo = diag_field%is_halo_present()
! if no mask just go for it
mask: if(.not. diag_field%is_mask_variant()) then
error_string = diag_buff%diag_reduction_done_wrapper( &
has_mask = diag_field%has_mask_variant()
if(has_mask) has_mask = diag_field%get_mask_variant()
!! TODO pass in entire mask with anything needed for adjusting/grabbing the right region to
!! match output buffer
error_string = diag_buff%diag_reduction_done_wrapper( &
field_yaml%get_var_reduction(), &
mval, subregional, has_halo)
! if mask, need to check if zbounds as well for adjustment
else
zbounds: if(.not. field_yaml%has_var_zbounds()) then
! mask and no z-bounds, send mask as is
error_string = diag_buff%diag_reduction_done_wrapper( &
field_yaml%get_var_reduction(), &
mval, subregional, has_halo, &
mask=diag_field%get_mask())
else
! mask and zbounds, needs to adjust mask
mask_zbounds = field_yaml%get_var_zbounds()
mask_shape = diag_buff%get_buffer_dims()
mask_tmp = diag_field%get_mask()
! copy of masks are starting from one, potentially could be an issue with weirder masks
allocate(mask_adj(mask_shape(1), mask_shape(2), mask_zbounds(1):mask_zbounds(2), mask_shape(4)))
mask_adj(:,:,:,:) = mask_tmp(1:mask_shape(1), 1:mask_shape(2), mask_zbounds(1):mask_zbounds(2), &
1:mask_shape(4))
error_string = diag_buff%diag_reduction_done_wrapper( &
field_yaml%get_var_reduction(), &
mval, subregional, has_halo, &
mask=mask_adj)
deallocate(mask_tmp, mask_adj)
endif zbounds
endif mask
mval, has_mask)
endif
endif
!endif
Expand Down
73 changes: 25 additions & 48 deletions diag_manager/fms_diag_output_buffer.F90
Original file line number Diff line number Diff line change
Expand Up @@ -591,70 +591,47 @@ function do_time_sum_wrapper(this, field_data, mask, is_masked, bounds_in, bound
end function do_time_sum_wrapper

!> Finishes calculations for any reductions that use an average (avg, rms, pow)
function diag_reduction_done_wrapper(this, reduction_method, missing_value, is_subregional, has_halo, mask) &
!! TODO add mask and any other needed args for adjustment, and pass in the adjusted mask
!! to time_update_done
function diag_reduction_done_wrapper(this, reduction_method, missing_value, has_mask) & !! , has_halo, mask) &
result(err_msg)
class(fmsDiagOutputBuffer_type), intent(inout) :: this !< Updated buffer object
integer, intent(in) :: reduction_method !< enumerated reduction type from diag_data
real(kind=r8_kind), intent(in) :: missing_value !< missing_value for masked data points
logical, intent(in) :: is_subregional !< if subregional output (TODO can prob be removed)
logical, intent(in) :: has_halo !< true if halo region is being used
logical, optional, intent(in) :: mask(:,:,:,:) !< whether a mask variant reduction
logical, intent(in) :: has_mask !< indicates if there was a mask used during buffer updates
character(len=51) :: err_msg !< error message to return, blank if sucessful
logical, allocatable :: mask_tmp(:,:,:,:)
integer :: is, ie, js, je, ks, ke, zs, ze
!logical, intent(in) :: is_subregional !< if subregional output
!logical, intent(in) :: has_halo !< true if halo region is being used
!logical, optional, intent(in) :: mask(:,:,:,:) !< whether a mask variant reduction
!logical, allocatable :: mask_tmp(:,:,:,:)
!integer :: is, ie, js, je, ks, ke, zs, ze
!integer :: i, halo_size(4)

if(.not. allocated(this%buffer)) return

if(this%weight_sum .eq. 0.0_r8_kind) return

! TODO mask adjustment for halos, not needed unless were passing in the mask
! if the mask is stil bigger than the buffer, theres a halo region we can leave out
if(has_halo .and. present(mask)) then
is = lbound(this%buffer,1); ie = ubound(this%buffer,1)
js = lbound(this%buffer,2); je = ubound(this%buffer,2)
ks = lbound(this%buffer,3); ke = ubound(this%buffer,3)
zs = lbound(this%buffer,4); ze = ubound(this%buffer,4)
allocate(mask_tmp(is:ie,js:je,ks:ke,zs:ze))
mask_tmp = .true.
! TODO this is basically creating a new mask instead of adjusting the original one
! not ideal, only needed for mask+halo cases
select type(buff => this%buffer)
type is(real(r8_kind))
where(buff(:,:,:,:,1) .eq. missing_value)
mask_tmp(:,:,:,:) = .false.
endwhere
type is(real(r4_kind))
where(buff(:,:,:,:,1) .eq. missing_value)
mask_tmp(:,:,:,:) = .false.
endwhere
end select
!mask_tmp(is:ie,js:je,ks:ke,zs:ze) = mask(is:ie,js:je,ks:ke,zs:ze)
!print *, "adjusted mask bounds:", is, ie, js, je, ks, ke, zs, ze, "all mask_tmp, mask", all(mask_tmp), all(mask)
endif
!if(has_halo .and. present(mask)) then
!is = lbound(this%buffer,1); ie = ubound(this%buffer,1)
!js = lbound(this%buffer,2); je = ubound(this%buffer,2)
!ks = lbound(this%buffer,3); ke = ubound(this%buffer,3)
!zs = lbound(this%buffer,4); ze = ubound(this%buffer,4)
!! might be safe to assume these are all the same
!do i=1, 4
!halo_size(i) = (SIZE(this%buffer,i) - SIZE(mask,i)) / 2
!enddo
!mask_tmp = mask(is+halo_size(1):ie+halo_size(1), js+halo_size(2):je+halo_size(2), ks+halo_size(3):ke+halo_size(3),&
!zs+halo_size(4):ze+halo_size(4))
!endif

err_msg = ""
select type(buff => this%buffer)
type is (real(r8_kind))
if(present(mask)) then
! call with adjusted mask if halo
if(has_halo) then
call time_update_done(buff, this%weight_sum, reduction_method, missing_value, mask_tmp)
else
call time_update_done(buff, this%weight_sum, reduction_method, missing_value, mask)
endif
else
call time_update_done(buff, this%weight_sum, reduction_method, missing_value)
endif
call time_update_done(buff, this%weight_sum, reduction_method, missing_value, has_mask)
type is (real(r4_kind))
if(present(mask)) then
! call with adjusted mask if halo
if(has_halo) then
call time_update_done(buff, this%weight_sum, reduction_method, real(missing_value, r4_kind), mask_tmp)
else
call time_update_done(buff, this%weight_sum, reduction_method, real(missing_value, r4_kind), mask)
endif
else
call time_update_done(buff, this%weight_sum, reduction_method, real(missing_value, r4_kind))
endif
call time_update_done(buff, this%weight_sum, reduction_method, real(missing_value, r4_kind), has_mask)
end select
this%weight_sum = 0.0_r8_kind

Expand Down
16 changes: 7 additions & 9 deletions diag_manager/include/fms_diag_reduction_methods.inc
Original file line number Diff line number Diff line change
Expand Up @@ -300,23 +300,21 @@ end subroutine DO_TIME_SUM_UPDATE_

!> To be called with diag_send_complete, finishes reductions
!! Just divides the buffer by the counter array(which is just the sum of the weights used in the buffer's reduction)
subroutine SUM_UPDATE_DONE_(out_buffer_data, weight_sum, reduction_method, missing_val, mask)
!! TODO: change has_mask to an actual logical mask so we don't have to check for missing values
subroutine SUM_UPDATE_DONE_(out_buffer_data, weight_sum, reduction_method, missing_val, has_mask)
real(FMS_TRM_KIND_), intent(inout) :: out_buffer_data(:,:,:,:,:) !< data buffer previosuly updated with do_time_sum_update
real(r8_kind), intent(in) :: weight_sum !< sum of weights for averaging, provided via argument to send data
integer, intent(in) :: reduction_method !< which reduction method to use, should be time_avg
real(FMS_TRM_KIND_), intent(in) :: missing_val !< missing value for masked elements
logical, optional, intent(in) :: mask(:,:,:,:) !< logical mask from accept data call, if using one
logical :: has_mask !< whether or not mask is present
integer, parameter :: kindl = FMS_TRM_KIND_
has_mask = present(mask)
logical, intent(in) :: has_mask !< indicates if mask is used so missing values can be skipped from avg'ing
!! TODO replace conditional in the `where` with passed in and ajusted mask from the original call
!logical, optional, intent(in) :: mask(:,:,:,:) !< logical mask from accept data call, if using one.
!logical :: has_mask !< whether or not mask is present

if ( has_mask ) then
where(mask(:,:,:,:))
where(out_buffer_data(:,:,:,:,1) .ne. missing_val)
out_buffer_data(:,:,:,:,1) = out_buffer_data(:,:,:,:,1) &
/ weight_sum
elsewhere
out_buffer_data(:,:,:,:,1) = missing_val
endwhere
else !not mask variant
out_buffer_data(:,:,:,:,1) = out_buffer_data(:,:,:,:,1) &
Expand Down

0 comments on commit 3a9997b

Please sign in to comment.