Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FALCON-40B Inference-Kernel Support #3656

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
int max_out_tokens,
bool multi_query)
{
constexpr int T_per_thread = granularity / sizeof(T);
constexpr int heads_per_block = rot_half::threads / threadsPerHead;
Expand All @@ -43,7 +44,8 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
const int head_idx = blockIdx.x * heads_per_block + threadIdx.x / threadsPerHead;
const int cur_seq_idx = head_idx % seq_len;
const int offset = head_idx * head_size;
const int k_offset = (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size;
const int k_offset =
multi_query ? offset : (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size;

const int seq_idx = cur_seq_idx + seq_offset;
const int half_dim = rotary_dim >> 1;
Expand Down Expand Up @@ -86,16 +88,17 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
}
}

#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \
apply_rotary_pos_half<T, HEAD_THREADS, ALIGNMENT><<<grid, block, 0, stream>>>(mixed_query, \
key_layer, \
rotary_dim, \
seq_len, \
offset, \
num_heads, \
head_size, \
total_count, \
max_out_tokens);
#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \
apply_rotary_pos_half<T, HEAD_THREADS, ALIGNMENT><<<grid, block, 0, stream>>>(mixed_query, \
key_layer, \
rotary_dim, \
seq_len, \
offset, \
num_heads, \
head_size, \
total_count, \
max_out_tokens, \
multi_query);

#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
Expand Down Expand Up @@ -137,7 +140,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned num_heads,
unsigned batch,
cudaStream_t stream,
int max_out_tokens)
int max_out_tokens,
bool multi_query)
{
const int half_dim = rotary_dim >> 1;

Expand Down Expand Up @@ -176,9 +180,18 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
}
}

#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \
template void launch_apply_rotary_pos_emb<T>( \
T*, T*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, cudaStream_t, int);
#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \
template void launch_apply_rotary_pos_emb<T>(T*, \
T*, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
cudaStream_t, \
int, \
bool);

INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float);
#ifdef BF16_AVAILABLE
Expand Down
58 changes: 39 additions & 19 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,22 @@ __global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int in
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(
data_bias, bias + (offset % intermediate_size), bias != nullptr);
if (bias) {
mem_access::load_global<granularity>(
data_bias, bias + (offset % intermediate_size), bias != nullptr);

#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(gelu(data_f + bias_f));
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(gelu(data_f + bias_f));
}
} else {
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
data[i] = conversion::to<T>(gelu(data_f));
}
}

mem_access::store_global<granularity>(input + offset, data);
Expand Down Expand Up @@ -321,11 +329,17 @@ __global__ void gptj_residual_add(float* residual,
res_fl4.z += attn_bias_fl4.z;
res_fl4.w += attn_bias_fl4.w;
}
if (bias) {
res_fl4.x += bias_fl4.x;
res_fl4.y += bias_fl4.y;
res_fl4.z += bias_fl4.z;
res_fl4.w += bias_fl4.w;
}
// residual = hidden_state + attention + (residual + bias) * mp_scale
res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale;
res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale;
res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale;
res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale;
res_fl4.x = hs_fl4.x + attn_fl4.x + res_fl4.x * mp_scale;
res_fl4.y = hs_fl4.y + attn_fl4.y + res_fl4.y * mp_scale;
res_fl4.z = hs_fl4.z + attn_fl4.z + res_fl4.z * mp_scale;
res_fl4.w = hs_fl4.w + attn_fl4.w + res_fl4.w * mp_scale;

res_fl4_ptr[offset] = res_fl4;
}
Expand Down Expand Up @@ -354,12 +368,10 @@ __global__ void gptj_residual_add(T* residual,
float2 res_fl2 = res_fl2_ptr[offset];
const float2 hs_fl2 = hs_fl2_ptr[offset];
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];

T2* res_half2 = reinterpret_cast<T2*>(&res_fl2);
const T2* hs_half2 = reinterpret_cast<const T2*>(&hs_fl2);
const T2* attn_half2 = reinterpret_cast<const T2*>(&attn_fl2);
const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);

float2 res_low = conversion::to<float2>(res_half2[0]);
float2 res_high = conversion::to<float2>(res_half2[1]);
Expand All @@ -370,9 +382,6 @@ __global__ void gptj_residual_add(T* residual,
const float2 attn_low = conversion::to<float2>(attn_half2[0]);
const float2 attn_high = conversion::to<float2>(attn_half2[1]);

const float2 bias_low = conversion::to<float2>(bias_half2[0]);
const float2 bias_high = conversion::to<float2>(bias_half2[1]);

if (attn_bias) {
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
const T2* attn_bias_half2 = reinterpret_cast<const T2*>(&attn_bias_fl2);
Expand All @@ -384,11 +393,22 @@ __global__ void gptj_residual_add(T* residual,
res_high.x += attn_bias_high.x;
res_high.y += attn_bias_high.y;
}
if (bias) {
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
const float2 bias_low = conversion::to<float2>(bias_half2[0]);
const float2 bias_high = conversion::to<float2>(bias_half2[1]);
// residual += attention_bias
res_low.x += bias_low.x;
res_low.y += bias_low.y;
res_high.x += bias_high.x;
res_high.y += bias_high.y;
}
// residual = hidden_state + attention + (residual + bias) * mp_scale
res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale;
res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale;
res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
res_low.x = attn_low.x + hs_low.x + res_low.x * mp_scale;
res_low.y = attn_low.y + hs_low.y + res_low.y * mp_scale;
res_high.x = attn_high.x + hs_high.x + res_high.x * mp_scale;
res_high.y = attn_high.y + hs_high.y + res_high.y * mp_scale;

res_half2[0] = conversion::to<T2>(res_low);
res_half2[1] = conversion::to<T2>(res_high);
Expand Down
Loading