diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh index d2d040a0e..5fa93ee12 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh @@ -29,7 +29,7 @@ class nvshmem_device_reference { : pointer_(static_cast(nvshmem_ref.pointer)), typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT)) { - assert(gref.stride % sizeof(DataTypeT) == 0); + assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0); } __device__ nvshmem_device_reference() = delete; diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu index a860cbc6c..4051f12bd 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu @@ -185,6 +185,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( p_env_fns, stream); // ungistre + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); if (nvshmemx_buffer_unregister(temp_output_ptr) != 0) { WHOLEMEMORY_ERROR("nvshmemx_buffer_unregister error in wholememory_gather_nvshmem"); }