avoid mixing jinja2 and compile-time expression
amirakb89 authored and avbokovoy committed Feb 4, 2025
1 parent d3e63b8 commit 2eeaaf7
Showing 1 changed file with 16 additions and 20 deletions.
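
The diff below removes a Jinja2 {%- if kWarpSize==64 %} block that emitted one of two constexpr definitions and replaces it with a single C++ ternary folded at compile time, renaming the constant to AccumulateStoreRequests. A minimal sketch of the pattern follows; kWarpSize and MaxNum128BRows are placeholder values standing in for the kernel's real compile-time parameters, not FBGEMM code.

#include <cstdint>

// Placeholders standing in for the kernel's compile-time parameters.
constexpr int32_t kWarpSize = 64;
constexpr int32_t MaxNum128BRows = 4;

// Before: the Jinja2 template chose between two different source lines.
//
//   {%- if kWarpSize==64 %}
//   constexpr int32_t MaxNum128BRows_req = (MaxNum128BRows+1) / 2;
//   {%- else %}
//   constexpr int32_t MaxNum128BRows_req = MaxNum128BRows;
//   {%- endif %}

// After: one source line; the ternary is evaluated at compile time because
// both operands are integral constant expressions.
constexpr int32_t AccumulateStoreRequests =
    (kWarpSize == 64) ? (MaxNum128BRows + 1) / 2 : MaxNum128BRows;

static_assert(AccumulateStoreRequests == (MaxNum128BRows + 1) / 2,
              "with 64-lane warps each thread issues half as many 128B store requests");

Both branches are constant expressions, so the generated kernel code is unchanged; only the template-time branching goes away.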
@@ -85,12 +85,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
if (D_bytes <= MinNum128BRows * 128 || D_bytes > MaxNum128BRows * 128) {
return;
}
-{%- if kWarpSize==64 %}
-constexpr int32_t MaxNum128BRows_req = (MaxNum128BRows+1) / 2;
-{%- else %}
-constexpr int32_t MaxNum128BRows_req = MaxNum128BRows;
-{%- endif %}

+constexpr int32_t AccumulateStoreRequests = (kWarpSize == 64) ? (MaxNum128BRows + 1) / 2 : MaxNum128BRows;
const int64_t weights_offset = weights_offsets[t];
const int32_t D_total = padded_D(D, weight_ty);
const int32_t D_padding = D_total - D;
@@ -125,7 +121,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
}

{% if not nobag %}
-VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][MaxNum128BRows_req];
+VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][AccumulateStoreRequests];
{% endif %}

for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) {
@@ -277,8 +273,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
using scalar_t = {{ emb_weight_type.cpp_type_name }};

{% if not nobag %}
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
{% if weighted %}
accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight);
@@ -289,8 +285,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
{% else %}
const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) {
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
// Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
// We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to
// the scale/shift handling).
@@ -309,8 +305,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
auto thread_local_min = std::numeric_limits<float>::max();
auto thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
@@ -320,8 +316,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
}
}
qparams = warp_find_qparams(thread_local_min, thread_local_max);
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
if (output_d >= 0 && output_d < D) {
@@ -348,8 +344,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast<float>(1.0) / Ls[i] : static_cast<float>(1.0);

if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) {
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
if constexpr (PackedMode) {
output_d -= packed_bag_idx * kOutputsPerThread * num_stores_with_padding_per_row;
@@ -375,8 +371,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
float thread_local_min = std::numeric_limits<float>::max();
float thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
accumulators[i][j].mul(inv_L);
if (output_d >= 0 && output_d < D) {
@@ -388,8 +384,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
qparams = warp_find_qparams(thread_local_min, thread_local_max);
const int output_D_start = D_start + t * 8;
const int output_D_end = output_D_start + D;
-#pragma unroll MaxNum128BRows_req
-for (uint32_t j = 0; j < MaxNum128BRows_req; ++j) {
+#pragma unroll AccumulateStoreRequests
+for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
if (output_d >= 0 && output_d < D) {
const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }}));
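To make the effect of the rename concrete, here is a self-contained sketch, not FBGEMM code: the kernel body, buffer sizes, and constant values are invented for illustration. It shows the single constexpr driving the accumulator array size, the #pragma unroll count, and the loop bound, as in the hunks above.

#include <cstdint>

constexpr int32_t kWarpSize = 64;       // placeholder; the real kernel sets this per target
constexpr int32_t MaxNum128BRows = 4;   // placeholder row count
constexpr int32_t AccumulateStoreRequests =
    (kWarpSize == 64) ? (MaxNum128BRows + 1) / 2 : MaxNum128BRows;

// Invented kernel: one constant sizes the accumulator array, sets the unroll
// count, and bounds the loops, with no template-time branching. The caller is
// assumed to pass buffers large enough for kWarpSize * AccumulateStoreRequests floats.
__global__ void accumulate_sketch(const float* row, float* out) {
  float accumulators[AccumulateStoreRequests] = {};
  #pragma unroll AccumulateStoreRequests
  for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
    accumulators[j] += row[kWarpSize * j + threadIdx.x];
  }
  #pragma unroll AccumulateStoreRequests
  for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
    out[kWarpSize * j + threadIdx.x] = accumulators[j];
  }
}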
