Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Nov 9, 2023
1 parent e510e90 commit a79f840
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions include/mscclpp/packet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#ifndef MSCCLPP_PACKET_HPP_
#define MSCCLPP_PACKET_HPP_

#include "atomic.hpp"
#include "poll_device.hpp"

namespace mscclpp {
Expand All @@ -28,32 +29,34 @@ union alignas(16) LLPacket {
/// @param val2 The second 4-byte data to write.
/// @param flag The flag to write.
MSCCLPP_DEVICE_INLINE void write(uint32_t val1, uint32_t val2, uint32_t flag) {
// Do not directly write on `raw_` to make sure that this is interpreted as two 8-byte writes,
// not four 4-byte writes.
uint4 reg = make_uint4(val1, flag, val2, flag);
raw_ = *reinterpret_cast<ulonglong2*>(&reg);
ulonglong2* p = reinterpret_cast<ulonglong2*>(&reg);
atomicStore(&(raw_.x), p->x, memoryOrderRelaxed);
atomicStore(&(raw_.y), p->y, memoryOrderRelaxed);
}

/// Write 8 bytes of data to the packet.
/// @param val The 8-byte data to write.
/// @param flag The flag to write.
MSCCLPP_DEVICE_INLINE void write(uint64_t val, uint32_t flag) {
// Do not directly write on `raw_` to make sure that this is interpreted as two 8-byte writes,
// not four 4-byte writes.
uint4 reg = make_uint4((uint32_t)val, flag, (uint32_t)(val >> 32), flag);
raw_ = *reinterpret_cast<ulonglong2*>(&reg);
ulonglong2* p = reinterpret_cast<ulonglong2*>(&reg);
atomicStore(&(raw_.x), p->x, memoryOrderRelaxed);
atomicStore(&(raw_.y), p->y, memoryOrderRelaxed);
}

/// Helper of @ref read().
/// @param flag The flag to read.
/// @param data The 8-byte data read.
/// @return True if the flag is not equal to the given flag.
MSCCLPP_DEVICE_INLINE bool readOnce(uint32_t flag, uint2& data) const {
ulonglong2 reg = raw_;
ulonglong2 reg;
reg.x = atomicLoad(&(raw_.x), memoryOrderRelaxed);
reg.y = atomicLoad(&(raw_.y), memoryOrderRelaxed);
uint4* ptr = reinterpret_cast<uint4*>(&reg);
data.x = ptr->w;
data.y = ptr->y;
return (ptr->x != flag) || (ptr->z != flag);
data.x = ptr->x;
data.y = ptr->z;
return (ptr->y != flag) || (ptr->w != flag);
}

/// Read 8 bytes of data from the packet.
Expand Down

0 comments on commit a79f840

Please sign in to comment.