🐛 fix buggy usage of non-contiguous slabs (#451)
jmd-dk authored Nov 10, 2024
1 parent 697ce40 commit feaacee
Showing 2 changed files with 22 additions and 58 deletions.
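Context for the fix: these slabs are FFTW in-place real-to-complex buffers, so for an even grid size N each row along the last axis carries N + 2 doubles, the final 2 being transform padding. Any view that drops the padding is necessarily non-contiguous, which clashes with the C-contiguous `double[:, :, ::1]` memoryview typing used throughout. A minimal NumPy sketch of the layout, with assumed sizes:

```python
# Minimal sketch with assumed sizes (not CO*N*CEPT code): for even N, an
# in-place real-to-complex FFT stores N + 2 doubles along the last axis.
import numpy as np

N, nprocs = 8, 2
slab = np.empty((N//nprocs, N, N + 2))  # padded slab on one of two processes

data = slab[:, :, :N]  # drop the 2 padding doubles of each row
print(data.shape)                  # (4, 8, 8)
print(data.flags['C_CONTIGUOUS'])  # False: rows still stride over the padding
```

The commit therefore removes slab trimming entirely: slabs stay padded and contiguous, and the padding is excluded only at the HDF5 read/write boundary (see src/snapshot.py below).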
src/mesh.py (7 additions, 43 deletions)
@@ -2249,7 +2249,6 @@ def domain_decompose(
     grid='double[:, :, ::1]',
     slab_or_buffer_name=object,  # double[:, :, ::1], int or str
     prepare_fft='bint',
-    trim='bint',
     # Locals
     N_domain2slabs_communications='Py_ssize_t',
     buffer_name=object,  # int or str
@@ -2273,7 +2272,6 @@ def domain_decompose(
     should_recv='bint',
     slab='double[:, :, ::1]',
     slab_arr=object,
-    slab_maybe_trimmed='double[:, :, ::1]',
     slab_sendrecv_j_end='int[::1]',
     slab_sendrecv_j_start='int[::1]',
     slab_sendrecv_k_end='int[::1]',
@@ -2283,7 +2281,7 @@ def domain_decompose(
     ℓ='Py_ssize_t',
     returns='double[:, :, ::1]',
 )
-def slab_decompose(grid, slab_or_buffer_name=None, prepare_fft=False, trim=False):
+def slab_decompose(grid, slab_or_buffer_name=None, prepare_fft=False):
     """This function communicates a global domain decomposed grid into
     a global slab decomposed grid. If an existing slab grid should be
     used it can be passed as the second argument.
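After this change a caller can no longer ask slab_decompose() for a trimmed slab; it always receives the full padded slab and slices the padding off itself where needed. A hypothetical call-site pattern, using a NumPy stand-in for the returned slab (assumed sizes):

```python
# Hypothetical call-site pattern under the new API; np.empty stands in
# for the slab returned by slab_decompose() (assumed sizes).
import numpy as np

slab = np.empty((4, 8, 10))                        # e.g. slab_decompose(grid)
data = np.asarray(slab)[:, :, :slab.shape[2] - 2]  # exclude padding on use
print(data.shape)  # (4, 8, 8)
```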
@@ -2351,14 +2349,9 @@ def slab_decompose(grid, slab_or_buffer_name=None, prepare_fft=False, trim=False):
         slab_sendrecv_k_start,
         slab_sendrecv_k_end,
     ) = prepare_decomposition(grid, slab)
-    # Trim the slab if requested.
-    # All further operations work on both full and trimmed slabs.
-    slab_maybe_trimmed = slab
-    if trim:
-        slab_maybe_trimmed = trim_slab(slab)
     # Communicate the domain grid to the slabs in chunks
     n_chunks, thickness_chunk = get_slab_domain_decomposition_chunk_size(
-        slab_maybe_trimmed, grid_noghosts,
+        slab, grid_noghosts,
     )
     for ℓ in range(N_domain2slabs_communications):
         should_send = (ℓ < ℤ[slabs2domain_sendrecv_ranks.shape[0]])
@@ -2398,10 +2391,10 @@ def slab_decompose(grid, slab_or_buffer_name=None, prepare_fft=False, trim=False):
             if should_recv:
                 chunk_recv_bgn = i_chunk*thickness_chunk
                 chunk_recv_end = chunk_recv_bgn + thickness_chunk
-                if chunk_recv_end > ℤ[slab_maybe_trimmed.shape[0]]:
-                    chunk_recv_end = ℤ[slab_maybe_trimmed.shape[0]]
+                if chunk_recv_end > ℤ[slab.shape[0]]:
+                    chunk_recv_end = ℤ[slab.shape[0]]
                 smart_mpi(
-                    slab_maybe_trimmed[
+                    slab[
                         chunk_recv_bgn:chunk_recv_end,
                         ℤ[slab_sendrecv_j_start[ℓ]]:ℤ[slab_sendrecv_j_end[ℓ]],
                         ℤ[slab_sendrecv_k_start[ℓ]]:ℤ[slab_sendrecv_k_end[ℓ]],
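The chunked receive logic itself is unchanged apart from the renaming: the final chunk is clamped so it never runs past the slab thickness. A standalone sketch of the bounds computation, with made-up numbers:

```python
# Standalone sketch of the chunk-bounds clamping above (made-up numbers)
thickness = 10       # slab.shape[0]
thickness_chunk = 4  # slab layers per communication chunk
n_chunks = 3

for i_chunk in range(n_chunks):
    chunk_recv_bgn = i_chunk*thickness_chunk
    chunk_recv_end = min(chunk_recv_bgn + thickness_chunk, thickness)
    print(chunk_recv_bgn, chunk_recv_end)  # prints: 0 4 / 4 8 / 8 10
```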
@@ -2415,7 +2408,7 @@ def slab_decompose(grid, slab_or_buffer_name=None, prepare_fft=False, trim=False):
             # overwritten by the next (non-blocking) send.
             if should_send:
                 request.wait()
-    return slab_maybe_trimmed
+    return slab

 # Helper function for slab and domain decomposition
 @cython.header(
@@ -3748,38 +3741,13 @@ def fill_slab_padding(slab, value):
     for index, i, j, k in slab_loop(gridsize, skip_data=True):
         slab_ptr[index] = value

-# Function for trimming away the padded elements of slabs
-@cython.header(
-    # Arguments
-    slab='double[:, :, ::1]',
-    # Locals
-    arr=object,  # np.ndarray
-    gridsize='Py_ssize_t',
-    slab_trimmed='double[:, :, ::1]',
-    returns='double[:, :, ::1]',
-)
-def trim_slab(slab):
-    """Note that trimmed slabs may not be used with FFTW"""
-    arr = asarray(slab)
-    shape = arr.shape
-    gridsize = slab.shape[1]
-    if shape != get_slabshape_local(gridsize):
-        abort(f'trim_slab() got slab of erroneous shape {shape}')
-    slab_trimmed = (
-        arr.reshape(ℤ[slab.shape[0]*slab.shape[1]]*slab.shape[2])
-        [:ℤ[slab.shape[0]*slab.shape[1]]*ℤ[slab.shape[2] - 2]]
-        .reshape((slab.shape[0], slab.shape[1], ℤ[slab.shape[2] - 2]))
-    )
-    return slab_trimmed
-
 # Function that returns a slab decomposed grid,
 # allocated by FFTW.
 @cython.pheader(
     # Arguments
     gridsize='Py_ssize_t',
     buffer_name=object,  # int or str or None
     nullify=object,  # bint, str or list of str's
-    trim='bint',
     # Locals
     acquire='bint',
     as_expected='bint',
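The deleted trim_slab() avoided a non-contiguous slice by flattening the padded buffer, keeping a contiguous prefix and reshaping. Relative to the padded layout, though, the resulting view is row-misaligned: for a padded slab of shape (a, b, c), element [i, j, k] of the "trimmed" result sits at flat offset i*b*(c - 2) + j*(c - 2) + k, while the actual data element lives at i*b*c + j*c + k. Any code touching the same buffer through both the padded and the trimmed layout therefore sees scrambled data. A toy NumPy demonstration:

```python
# Demonstration that the removed reshape trick misaligns rows (toy shapes)
import numpy as np

a, b, c = 2, 3, 6             # padded slab shape; data occupies c - 2 = 4
slab = np.arange(a*b*c, dtype=float).reshape(a, b, c)
slab[:, :, c - 2:] = -1       # mark the 2 padding doubles of each row

correct = slab[:, :, :c - 2]  # non-contiguous but correct view
trimmed = slab.reshape(a*b*c)[:a*b*(c - 2)].reshape(a, b, c - 2)

print(np.array_equal(correct, trimmed))  # False
print(trimmed[0, 1])  # [-1. -1.  6.  7.] rather than [6. 7. 8. 9.]
```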
@@ -3798,16 +3766,14 @@ def trim_slab(slab):
     wisdom_filename=str,
     returns='double[:, :, ::1]',
 )
-def get_fftw_slab(gridsize, buffer_name=None, nullify=False, trim=False):
+def get_fftw_slab(gridsize, buffer_name=None, nullify=False):
     global fftw_plans_size, fftw_plans_forward, fftw_plans_backward
     if buffer_name is None:
         buffer_name = 'slab_global'
     # If this slab has already been constructed, fetch it
     slab = slabs.get((gridsize, buffer_name))
     if slab is not None:
         nullify_modes(slab, nullify)
-        if trim:
-            return trim_slab(slab)
         return slab
     # Checks on the passed gridsize
     if gridsize%nprocs != 0:
@@ -3897,8 +3863,6 @@ def get_fftw_slab(gridsize, buffer_name=None, nullify=False, trim=False):
     # Store and return this slab
     slabs[gridsize, buffer_name] = slab
     nullify_modes(slab, nullify)
-    if trim:
-        return trim_slab(slab)
     return slab
 # Cache storing slabs. The keys have the format (gridsize, buffer_name).
 cython.declare(slabs=dict)
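For reference, the cache consulted at the top of get_fftw_slab() is the plain dict declared just above, keyed by (gridsize, buffer_name). A simplified sketch of the pattern (assumed shape; the real code allocates through FFTW and validates gridsize against nprocs):

```python
# Simplified sketch of the (gridsize, buffer_name)-keyed slab cache
import numpy as np

slabs = {}

def get_slab(gridsize, buffer_name='slab_global'):
    # Fetch from the cache, allocating on first request
    slab = slabs.get((gridsize, buffer_name))
    if slab is None:
        slab = np.empty((gridsize, gridsize, gridsize + 2))  # single process
        slabs[gridsize, buffer_name] = slab
    return slab

assert get_slab(8) is get_slab(8)  # the second call is a cache hit
```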
src/snapshot.py (15 additions, 15 deletions)
@@ -121,9 +121,9 @@ def __init__(self):
     name=object,  # str or int
     plural=str,
     shape=tuple,
+    slab='double[:, :, ::1]',
     slab_end='Py_ssize_t',
     slab_start='Py_ssize_t',
-    slab_trimmed='double[:, :, ::1]',
     start_local='Py_ssize_t',
     returns=str,
 )
@@ -245,14 +245,14 @@ def save(self, filename, save_all=False):
                 # slab decomposed. Here we communicate the
                 # fluid scalar to slabs before saving to
                 # disk, improving performance enormously.
-                slab_trimmed = slab_decompose(fluidscalar.grid_mv, trim=True)
-                slab_start = ℤ[slab_trimmed.shape[0]]*rank
-                slab_end = slab_start + ℤ[slab_trimmed.shape[0]]
+                slab = slab_decompose(fluidscalar.grid_mv)
+                slab_start = slab.shape[0]*rank
+                slab_end = slab_start + slab.shape[0]
                 fluidscalar_h5[
                     slab_start:slab_end,
                     :,
                     :,
-                ] = slab_trimmed[:, :, :]
+                ] = slab[:, :, :(slab.shape[2] - 2)]  # exclude padding
                 # Create additional names (hard links) for the fluid
                 # groups and data sets. The names from
                 # component.fluid_names will be used, except for
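On the write side the fix keeps the padded slab and hands h5py a plain slice that excludes the padding; h5py copies from non-contiguous NumPy views without trouble, so no contiguous "trimmed" buffer is needed. A self-contained sketch of the write pattern, with assumed sizes on a single process:

```python
# Self-contained sketch of the write pattern (assumed sizes, single process)
import h5py
import numpy as np

N = 8
slab = np.random.random((N, N, N + 2))  # padded slab held by rank 0 of 1

with h5py.File('snapshot.hdf5', 'w') as f:
    dset = f.create_dataset('fluidscalar', shape=(N, N, N), dtype='f8')
    slab_start = slab.shape[0]*0        # rank = 0
    slab_end = slab_start + slab.shape[0]
    # Non-contiguous source view; h5py copies it into the unpadded dataset
    dset[slab_start:slab_end, :, :] = slab[:, :, :(slab.shape[2] - 2)]
```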
@@ -328,8 +328,8 @@ def save(self, filename, save_all=False):
     pos='double*',
     representation=str,
     size='Py_ssize_t',
+    slab='double[:, :, ::1]',
     slab_start='Py_ssize_t',
-    slab_trimmed='double[:, :, ::1]',
     snapshot_unit_length='double',
     snapshot_unit_mass='double',
     snapshot_unit_time='double',
@@ -544,14 +544,14 @@ def load(self, filename, only_params=False):
             fluidvar_h5 = component_h5[f'fluidvar_{index}']
             for multi_index in fluidvar.multi_indices:
                 fluidscalar_h5 = fluidvar_h5[f'fluidscalar_{multi_index}']
-                slab_trimmed = get_fftw_slab(gridsize, trim=True)
-                slab_start = ℤ[slab_trimmed.shape[0]]*rank
+                slab = get_fftw_slab(gridsize)
+                slab_start = ℤ[slab.shape[0]]*rank
                 # Load in using chunks. Large chunks are
                 # fine as no temporary buffer is used. The
                 # maximum possible chunk size is limited
                 # by MPI, though.
                 chunk_size = np.min((
-                    ℤ[slab_trimmed.shape[0]],
+                    ℤ[slab.shape[0]],
                     ℤ[self.chunk_size_max//8//gridsize**2],
                 ))
                 if chunk_size == 0:
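The chunk size here is a byte budget converted to whole slab layers: self.chunk_size_max bytes, divided by 8 bytes per double and by gridsize² doubles per layer. A worked example with an assumed 2 MiB budget:

```python
# Worked example of the chunk-size arithmetic (assumed 2 MiB budget)
chunk_size_max = 2*1024**2                   # bytes
gridsize = 256
chunk_size = chunk_size_max//8//gridsize**2  # 8 bytes per double
print(chunk_size)                            # 4 slab layers per chunk
```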
@@ -560,10 +560,10 @@ def load(self, filename, only_params=False):
                         'and may not be read in correctly'
                     )
                     chunk_size = 1
-                arr = asarray(slab_trimmed)
-                for index_i in range(0, ℤ[slab_trimmed.shape[0]], chunk_size):
-                    if index_i + chunk_size > ℤ[slab_trimmed.shape[0]]:
-                        chunk_size = ℤ[slab_trimmed.shape[0]] - index_i
+                arr = asarray(slab)
+                for index_i in range(0, ℤ[slab.shape[0]], chunk_size):
+                    if index_i + chunk_size > ℤ[slab.shape[0]]:
+                        chunk_size = ℤ[slab.shape[0]] - index_i
                     index_i_file = slab_start + index_i
                     source_sel = (
                         slice(index_i_file, index_i_file + chunk_size),
@@ -573,15 +573,15 @@ def load(self, filename, only_params=False):
                     dest_sel = (
                         slice(index_i, index_i + chunk_size),
                         slice(None),
-                        slice(None),
+                        slice(0, slab.shape[2] - 2),  # exclude padding
                     )
                     fluidscalar_h5.read_direct(
                         arr, source_sel=source_sel, dest_sel=dest_sel,
                     )
                 # Communicate the slabs directly to the
                 # domain decomposed fluid grids.
                 domain_decompose(
-                    slab_trimmed,
+                    slab,
                     component.fluidvars[index][multi_index].grid_mv,
                     do_ghost_communication=True,
                 )
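On the read side, dest_sel now excludes the padding columns of the destination slab, so each chunk of the on-disk (unpadded) dataset lands in the data region of the padded buffer. A self-contained sketch of the read_direct() pattern, reusing the dataset name from the write sketch above (assumed sizes, single process):

```python
# Self-contained sketch of the chunked read (assumed sizes, single process)
import h5py
import numpy as np

N, chunk_size = 8, 4
slab = np.empty((N, N, N + 2))              # padded destination buffer
with h5py.File('snapshot.hdf5', 'r') as f:  # file from the write sketch
    dset = f['fluidscalar']                 # shape (N, N, N), unpadded
    for index_i in range(0, N, chunk_size):
        dset.read_direct(
            slab,
            source_sel=np.s_[index_i:index_i + chunk_size, :, :],
            dest_sel=np.s_[index_i:index_i + chunk_size, :, :N],  # skip padding
        )
```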
