diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
index e743ffc3f64f..e7df74dceb50 100644
--- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
+++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -32,7 +32,8 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
                                       unsigned num_heads,
                                       unsigned head_size,
                                       unsigned total_count,
-                                      int max_out_tokens)
+                                      int max_out_tokens,
+                                      bool multi_query)
 {
     constexpr int T_per_thread = granularity / sizeof(T);
     constexpr int heads_per_block = rot_half::threads / threadsPerHead;
@@ -43,7 +44,8 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
     const int head_idx = blockIdx.x * heads_per_block + threadIdx.x / threadsPerHead;
     const int cur_seq_idx = head_idx % seq_len;
     const int offset = head_idx * head_size;
-    const int k_offset = (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size;
+    const int k_offset =
+        multi_query ? offset : (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size;
 
     const int seq_idx = cur_seq_idx + seq_offset;
     const int half_dim = rotary_dim >> 1;
@@ -86,16 +88,17 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
     }
 }
 
-#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT)                                        \
-    apply_rotary_pos_half<T, HEAD_THREADS, ALIGNMENT><<<grid, block, 0, stream>>>(mixed_query, \
-                                                                                  key_layer,   \
-                                                                                  rotary_dim,  \
-                                                                                  seq_len,     \
-                                                                                  offset,      \
-                                                                                  num_heads,   \
-                                                                                  head_size,   \
-                                                                                  total_count, \
-                                                                                  max_out_tokens);
+#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT)                                          \
+    apply_rotary_pos_half<T, HEAD_THREADS, ALIGNMENT><<<grid, block, 0, stream>>>(mixed_query,   \
+                                                                                  key_layer,     \
+                                                                                  rotary_dim,    \
+                                                                                  seq_len,       \
+                                                                                  offset,        \
+                                                                                  num_heads,     \
+                                                                                  head_size,     \
+                                                                                  total_count,   \
+                                                                                  max_out_tokens, \
+                                                                                  multi_query);
 
 #ifdef __HIP_PLATFORM_HCC__
 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
@@ -137,7 +140,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
                                  unsigned num_heads,
                                  unsigned batch,
                                  cudaStream_t stream,
-                                 int max_out_tokens)
+                                 int max_out_tokens,
+                                 bool multi_query)
 {
     const int half_dim = rotary_dim >> 1;
 
@@ -176,9 +180,18 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
     }
 }
 
-#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T)          \
-    template void launch_apply_rotary_pos_emb<T>(    \
-        T*, T*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, cudaStream_t, int);
+#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T)                       \
+    template void launch_apply_rotary_pos_emb<T>(T*,               \
+                                                 T*,               \
+                                                 unsigned,         \
+                                                 unsigned,         \
+                                                 unsigned,         \
+                                                 unsigned,         \
+                                                 unsigned,         \
+                                                 unsigned,         \
+                                                 cudaStream_t,     \
+                                                 int,              \
+                                                 bool);
 
 INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float);
 #ifdef BF16_AVAILABLE
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 3e6701d81e64..fe3d0a30ba12 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -38,14 +38,22 @@ __global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int in
     T data[values_per_access];
     T data_bias[values_per_access];
     mem_access::load_global<granularity>(data, input + offset);
-    mem_access::load_global<granularity>(
-        data_bias, bias + (offset % intermediate_size), bias != nullptr);
+    if (bias) {
+        mem_access::load_global<granularity>(
+            data_bias, bias + (offset % intermediate_size), bias != nullptr);
 
 #pragma unroll
-    for (int i = 0; i < values_per_access; i++) {
-        float data_f = conversion::to<float>(data[i]);
-        float bias_f = conversion::to<float>(data_bias[i]);
-        data[i] = conversion::to<T>(gelu(data_f + bias_f));
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            float bias_f = conversion::to<float>(data_bias[i]);
+            data[i] = conversion::to<T>(gelu(data_f + bias_f));
+        }
+    } else {
+#pragma unroll
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            data[i] = conversion::to<T>(gelu(data_f));
+        }
     }
 
     mem_access::store_global<granularity>(input + offset, data);
@@ -321,11 +329,17 @@ __global__ void gptj_residual_add(float* residual,
             res_fl4.z += attn_bias_fl4.z;
             res_fl4.w += attn_bias_fl4.w;
         }
+        if (bias) {
+            res_fl4.x += bias_fl4.x;
+            res_fl4.y += bias_fl4.y;
+            res_fl4.z += bias_fl4.z;
+            res_fl4.w += bias_fl4.w;
+        }
         // residual = hidden_state + attention + (residual + bias) * mp_scale
-        res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale;
-        res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale;
-        res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale;
-        res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale;
+        res_fl4.x = hs_fl4.x + attn_fl4.x + res_fl4.x * mp_scale;
+        res_fl4.y = hs_fl4.y + attn_fl4.y + res_fl4.y * mp_scale;
+        res_fl4.z = hs_fl4.z + attn_fl4.z + res_fl4.z * mp_scale;
+        res_fl4.w = hs_fl4.w + attn_fl4.w + res_fl4.w * mp_scale;
 
         res_fl4_ptr[offset] = res_fl4;
     }
@@ -354,12 +368,10 @@ __global__ void gptj_residual_add(T* residual,
         float2 res_fl2 = res_fl2_ptr[offset];
         const float2 hs_fl2 = hs_fl2_ptr[offset];
         const float2 attn_fl2 = attn_fl2_ptr[offset];
-        const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
 
         T2* res_half2 = reinterpret_cast<T2*>(&res_fl2);
         const T2* hs_half2 = reinterpret_cast<const T2*>(&hs_fl2);
         const T2* attn_half2 = reinterpret_cast<const T2*>(&attn_fl2);
-        const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
 
         float2 res_low = conversion::to<float2>(res_half2[0]);
         float2 res_high = conversion::to<float2>(res_half2[1]);
@@ -370,9 +382,6 @@ __global__ void gptj_residual_add(T* residual,
         const float2 attn_low = conversion::to<float2>(attn_half2[0]);
         const float2 attn_high = conversion::to<float2>(attn_half2[1]);
 
-        const float2 bias_low = conversion::to<float2>(bias_half2[0]);
-        const float2 bias_high = conversion::to<float2>(bias_half2[1]);
-
         if (attn_bias) {
             const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
             const T2* attn_bias_half2 = reinterpret_cast<const T2*>(&attn_bias_fl2);
@@ -384,11 +393,22 @@
             res_high.x += attn_bias_high.x;
             res_high.y += attn_bias_high.y;
         }
+        if (bias) {
+            const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
+            const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
+            const float2 bias_low = conversion::to<float2>(bias_half2[0]);
+            const float2 bias_high = conversion::to<float2>(bias_half2[1]);
+            // residual += bias
+            res_low.x += bias_low.x;
+            res_low.y += bias_low.y;
+            res_high.x += bias_high.x;
+            res_high.y += bias_high.y;
+        }
         // residual = hidden_state + attention + (residual + bias) * mp_scale
-        res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale;
-        res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale;
-        res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
-        res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
+        res_low.x = attn_low.x + hs_low.x + res_low.x * mp_scale;
+        res_low.y = attn_low.y + hs_low.y + res_low.y * mp_scale;
+        res_high.x = attn_high.x + hs_high.x + res_high.x * mp_scale;
+        res_high.y = attn_high.y + hs_high.y + res_high.y * mp_scale;
 
         res_half2[0] = conversion::to<T2>(res_low);
         res_half2[1] = conversion::to<T2>(res_high);
diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index 97db77bff8a2..a4cd30306d15 100644
--- 
a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -364,7 +364,8 @@ void attention_unfused(T* prev_key_cont, bool local_attention, int window_size, at::Tensor& alibi, - int layer_id) + int layer_id, + int kv_seq_stride) { float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; float alpha = norm_factor * norm_factor / layer_scale; @@ -384,7 +385,7 @@ void attention_unfused(T* prev_key_cont, workspace, CUBLAS_OP_T, CUBLAS_OP_N, - InferenceContext::Instance().GetMaxTokenLength() * k, + kv_seq_stride * k, seq_len * k, seq_len * soft_len, bsz * heads, @@ -417,7 +418,7 @@ void attention_unfused(T* prev_key_cont, (T*)output, CUBLAS_OP_N, CUBLAS_OP_N, - InferenceContext::Instance().GetMaxTokenLength() * k, + kv_seq_stride * k, seq_len * soft_len, seq_len * k, bsz * heads, @@ -444,11 +445,16 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bool no_masking, unsigned layer_id, unsigned num_layers, - at::Tensor& alibi) + at::Tensor& alibi, + bool multi_query, + int num_kv) { unsigned bsz = query_key_value.size(0); unsigned seq_len = query_key_value.size(1); - unsigned hidden_dim = query_key_value.size(2) / 3; + unsigned shared_kv_heads_size = query_key_value.size(2) / (num_kv * 2 + heads); + unsigned hidden_dim = multi_query + ? query_key_value.size(2) - (num_kv * 2 * shared_kv_heads_size) + : query_key_value.size(2) / 3; bool is_prompt = (seq_len > 1); @@ -467,34 +473,57 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); auto query_cont = workspace + 5 * buf_size; + auto key_cont = workspace + 6 * buf_size; + auto value_cont = + key_cont + bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; size_t offset = 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLength()) + layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; unsigned all_tokens = soft_len; auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); - size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; + size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLength() * + (multi_query ? (num_kv * k) : hidden_dim); + int kv_seq_stride = + (multi_query ? all_tokens : InferenceContext::Instance().GetMaxTokenLength()); T* temp_buf = (T*)output.data_ptr() + at::numel(output); - launch_bias_add_transform_0213((T*)query_cont, - kv_cache, - kv_cache + value_offset, - (T*)query_key_value.data_ptr(), - nullptr, - bsz, - seq_len, - (is_prompt ? 0 : soft_len - 1), - soft_len, - hidden_dim, - heads, - rotary_dim, - rotate_half, - rotate_every_two, - InferenceContext::Instance().GetCurrentStream(), - 3, - InferenceContext::Instance().GetMaxTokenLength()); + if (multi_query) { + launch_transform_multi_query((T*)query_cont, + (T*)key_cont, + (T*)value_cont, + workspace + offset, + workspace + offset + value_offset, + (T*)query_key_value.data_ptr(), + bsz, + seq_len, + soft_len, + hidden_dim, + heads, + num_kv, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + } else { + launch_bias_add_transform_0213((T*)query_cont, + kv_cache, + kv_cache + value_offset, + (T*)query_key_value.data_ptr(), + nullptr, + bsz, + seq_len, + (is_prompt ? 
0 : soft_len - 1), + soft_len, + hidden_dim, + heads, + rotary_dim, + rotate_half, + rotate_every_two, + InferenceContext::Instance().GetCurrentStream(), + 3, + InferenceContext::Instance().GetMaxTokenLength()); + } if (rotary_dim > 0 && rotate_half) launch_apply_rotary_pos_emb(query_cont, - kv_cache, + (multi_query ? key_cont : kv_cache), k, seq_len, rotary_dim, @@ -502,12 +531,13 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, heads, bsz, InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLength()); + kv_seq_stride, + multi_query); - attention_unfused(workspace + offset, + attention_unfused((multi_query ? key_cont : workspace + offset), (T*)query_cont, attn_mask, - workspace + offset + value_offset, + (multi_query ? value_cont : workspace + offset + value_offset), temp_buf, bsz, k, @@ -520,7 +550,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, local_attention, window_size, alibi, - layer_id); + layer_id, + kv_seq_stride); launch_transform4d_0213((T*)output.data_ptr(), temp_buf, bsz, @@ -531,22 +562,23 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, 1); if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); - auto prev_key = torch::from_blob(workspace + offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), - k * InferenceContext::Instance().GetMaxTokenLength(), - k, - 1}, - options); - - auto prev_value = - torch::from_blob(workspace + offset + value_offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), - k * InferenceContext::Instance().GetMaxTokenLength(), - k, - 1}, - options); + auto prev_key = torch::from_blob( + workspace + offset, + {bsz, (multi_query ? num_kv : heads), all_tokens, k}, + {(multi_query ? num_kv * k : hidden_dim) * InferenceContext::Instance().GetMaxTokenLength(), + k * InferenceContext::Instance().GetMaxTokenLength(), + k, + 1}, + options); + + auto prev_value = torch::from_blob( + workspace + offset + value_offset, + {bsz, (multi_query ? num_kv : heads), all_tokens, k}, + {(multi_query ? num_kv * k : hidden_dim) * InferenceContext::Instance().GetMaxTokenLength(), + k * InferenceContext::Instance().GetMaxTokenLength(), + k, + 1}, + options); return {output, prev_key, prev_value}; } @@ -903,7 +935,7 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, { int bsz = input.size(0) * input.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - workspace += (3 * bsz * input.size(2)); + workspace += (bsz * weight.size(1)); ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); if (q_int8) { @@ -952,7 +984,7 @@ std::vector ds_rms_qkv(at::Tensor& input, { const int bsz = input.size(0) * input.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - T* rms_norm_ptr = workspace + (3 * bsz * input.size(2)); + T* rms_norm_ptr = workspace + (bsz * weight.size(1)); int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); auto options = at::TensorOptions() @@ -1437,13 +1469,13 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, } if (act_func_type == ActivationFuncType::GELU) { launch_bias_gelu(intermediate, - (T*)bias.data_ptr(), + (T*)(bias.size(0) > 0 ? bias.data_ptr() : nullptr), (transposed_mode || q_int8) ? 
weight.size(0) : weight.size(1), bsz, InferenceContext::Instance().GetCurrentStream()); } else if (act_func_type == ActivationFuncType::ReLU) { launch_bias_relu(intermediate, - (T*)bias.data_ptr(), + (T*)(bias.size(0) > 0 ? bias.data_ptr() : nullptr), (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), bsz, InferenceContext::Instance().GetCurrentStream()); @@ -1806,7 +1838,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, static_cast(residual.data_ptr()), static_cast(hidden_state.data_ptr()), static_cast(attention_output.data_ptr()), - static_cast(final_bias.data_ptr()), + static_cast(final_bias.size(0) > 0 ? final_bias.data_ptr() : nullptr), static_cast((add_bias ? attention_bias.data_ptr() : nullptr)), hidden_size, bsz, @@ -1862,7 +1894,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, num_heads, bsz, InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLength()); + InferenceContext::Instance().GetMaxTokenLength(), + false); else launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), (__half*)key_cont.data_ptr(), @@ -1873,7 +1906,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, num_heads, bsz, InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLength()); + InferenceContext::Instance().GetMaxTokenLength(), + false); return {query_cont, key_cont}; } diff --git a/csrc/transformer/inference/csrc/relu.cu b/csrc/transformer/inference/csrc/relu.cu index 40926b776cf2..b851883a84ec 100644 --- a/csrc/transformer/inference/csrc/relu.cu +++ b/csrc/transformer/inference/csrc/relu.cu @@ -28,14 +28,22 @@ __global__ void fused_bias_relu(T* input, const T* bias, int total_count, int in T data[values_per_access]; T data_bias[values_per_access]; mem_access::load_global(data, input + offset); - mem_access::load_global( - data_bias, bias + (offset % intermediate_size), bias != nullptr); + if (bias) { + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll - for (int i = 0; i < values_per_access; i++) { - float data_f = conversion::to(data[i]); - float bias_f = conversion::to(data_bias[i]); - data[i] = conversion::to(relu(data_f + bias_f)); + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(relu(data_f + bias_f)); + } + } else { +#pragma unroll + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + data[i] = conversion::to(relu(data_f)); + } } mem_access::store_global(input + offset, data); diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 650f286a8f03..1a0bcc6e2882 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -707,3 +707,180 @@ INSTANTIATE_2B_LAUNCH_TRANSFORM4D(__half) #ifdef BF16_AVAILABLE INSTANTIATE_2B_LAUNCH_TRANSFORM4D(__nv_bfloat16) #endif + +__global__ void transform_multi_query(float* query, + float* key, + float* value, + float* k_cache, + float* v_cache, + const float* vals, + int hidden_dim, + int seq_length, + int all_tokens, + int heads, + int max_out_tokens) +{ +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +template +__global__ void transform_multi_query(T* query, + T* key, + T* value, + T* k_cache, + T* v_cache, + const T* vals, // qkv + int seq_length, + int all_tokens, + int heads, + int num_kv, + int blks, + int 
query_heads, + int fused_dim, + int qkv_dim, + int hidden_dim, + int max_out_tokens) +{ + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; + int d0_stride = fused_dim * seq_length; + int d1_stride = fused_dim; + int d2_stride = blockDim.x; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / blks; // kv count + int blk_count = blockIdx.z % blks; + int d2 = threadIdx.y + blk_count * blockDim.y; + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = d2_stride * seq_length; + int d0_out_stride = hidden_dim * seq_length; + + int d2_out_stride_kv = d2_stride * all_tokens; + int d0_out_stride_kv = hidden_dim * all_tokens; + + float4 vals_arr; + float4 output_arr; + + T2* vals_half = reinterpret_cast(&vals_arr); + T2* output_half = reinterpret_cast(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + float4* cache = reinterpret_cast(d2 < (query_heads << 1) ? k_cache : v_cache); + float4* output_vec = reinterpret_cast( + d2 < query_heads ? query : (d2 < (query_heads << 1) ? key : value)); + + vals_vec += (d0 * d0_stride); + vals_vec += (d1 * d1_stride); + vals_vec += (cnt * qkv_dim); + + if (d2 < query_heads) { + vals_vec += (d2 * d2_stride); + } else { + if (d2 < (query_heads << 1)) + vals_vec += (query_heads * d2_stride); + else + vals_vec += ((query_heads + 1) * d2_stride); + } + + output_vec += (d1 * d2_stride); + output_vec += (d0 * (d2 < query_heads ? d0_out_stride : d0_out_stride_kv)); + output_vec += (((d2 % query_heads) + cnt * query_heads) * + (d2 < query_heads ? d2_out_stride : d2_out_stride_kv)); + + if (d1 < seq_length || d2 >= query_heads) { + if (d2 < query_heads) + output_vec[d3] = vals_vec[d3]; + else { + if (d1 < seq_length) { + output_vec += (all_tokens - seq_length) * d2_stride; + output_vec[d3] = vals_vec[d3]; + } else { + cache += (d2_stride * max_out_tokens * num_kv) * d0 + d2_stride * d1 + + cnt * (d2_stride * max_out_tokens); + float4 inp = cache[d3]; + output_vec[d3] = cache[d3]; + } + } + } + + if (d1 < seq_length && (d2 == query_heads || d2 == (query_heads << 1))) { + cache += (d2_stride * max_out_tokens * num_kv) * d0 + + d2_stride * (all_tokens - seq_length + d1) + cnt * (d2_stride * max_out_tokens); + cache[d3] = vals_vec[d3]; + } +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_transform_multi_query(float* query, + float* key, + float* value, + float* k_cache, + float* v_cache, + const float* vals, + int batch_size, + int seq_length, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + cudaStream_t stream, + int max_out_tokens) +{ +} +template +void launch_transform_multi_query(T* query, + T* key, + T* value, + T* k_cache, + T* v_cache, + const T* vals, + int batch_size, + int seq_length, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + cudaStream_t stream, + int max_out_tokens) +{ + hidden_dim >>= 3; + int max_thread_blk = 1024 / (hidden_dim / heads); + int threadblks = (heads / num_kv) * 3; + int launch_blks = threadblks; + if (launch_blks > max_thread_blk) launch_blks = max_thread_blk; + int num_blks = (threadblks - 1) / launch_blks + 1; + dim3 block_dim(hidden_dim / heads, launch_blks); + dim3 grid_dim(batch_size, all_tokens, num_kv * num_blks); + + transform_multi_query<<>>( + query, + key, + value, + k_cache, + v_cache, + vals, + seq_length, + all_tokens, + heads, + num_kv, + num_blks, + heads / num_kv, + hidden_dim + num_kv * 2 * (hidden_dim / heads), + (heads / num_kv + 2) * (hidden_dim / heads), + 
hidden_dim, + max_out_tokens); +} + +#define INSTANTIATE_LAUNCH_TRANSFORM_MULTI_QUERY(T) \ + template void launch_transform_multi_query( \ + T*, T*, T*, T*, T*, const T*, int, int, int, int, int, int, cudaStream_t, int) + +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_TRANSFORM_MULTI_QUERY(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_TRANSFORM_MULTI_QUERY(__half); diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 8ba8c1c3e22c..49676c40a1ba 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -169,7 +169,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned num_heads, unsigned batch, cudaStream_t stream, - int max_out_tokens); + int max_out_tokens, + bool multi_query); template void launch_moe_res_matmul(T* residual, @@ -243,3 +244,18 @@ void launch_vector_add(T* out, float gamma, int num_elems, cudaStream_t stream); +template +void launch_transform_multi_query(T* query, + T* key, + T* value, + T* k_cache, + T* v_cache, + const T* vals, + int batch_size, + int seq_length, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + cudaStream_t stream, + int max_out_tokens); diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 396fe7db2447..d80f0bf612b3 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -108,6 +108,10 @@ def tp_parser(model): gem_list = gem_list + [layer] elif 'down_proj' in layer: gem_list = gem_list + [layer] + elif 'self_attention.dense' in layer and 'falcon' in str( + type(module)): # this is a hack to get the right linear layer for this model! + gem_list = gem_list + [layer] + layer_list = [] if gem_list != []: gem_list = list(set(gem_list)) diff --git a/deepspeed/module_inject/containers/__init__.py b/deepspeed/module_inject/containers/__init__.py index 1dab38b73f51..8bc36f9a5cc5 100644 --- a/deepspeed/module_inject/containers/__init__.py +++ b/deepspeed/module_inject/containers/__init__.py @@ -17,3 +17,4 @@ from .clip import DS_CLIPContainer, HFCLIPLayerPolicy from .unet import UNetPolicy from .vae import VAEPolicy +from .falcon import FALCONLayerPolicy, DS_FALCONContainer diff --git a/deepspeed/module_inject/containers/base.py b/deepspeed/module_inject/containers/base.py index a520664793ca..36358c4c316f 100644 --- a/deepspeed/module_inject/containers/base.py +++ b/deepspeed/module_inject/containers/base.py @@ -10,7 +10,6 @@ from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig from deepspeed.accelerator import get_accelerator - # If the intermediate size attribute is set DEFAULT_INTERMEDIATE_SIZE # it is assumed the intermediate size is 4x the embedding dimension DEFAULT_INTERMEDIATE_SIZE = -1 @@ -59,7 +58,7 @@ def __init__(self, policy, config, model_config, layer_id, child): self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False) self.use_mup = self.policy.use_mup self.return_single_tuple = False - self.rotary_dim = self.get_rotary_dim() + self.rotary_dim = self.get_rotary_dim(policy) self.mlp_after_attn = (self.rotary_dim is None or self.rotary_dim < 0) # Attention tensors @@ -140,11 +139,15 @@ def convert_to_required_dtype(self): if isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter): self.__dict__[k] = v.to(self.dtype) - def get_rotary_dim(self): + def get_rotary_dim(self, policy=None): + + from .falcon import 
FALCONLayerPolicy if hasattr(self.model_config, 'rotary_dim'): return self.model_config.rotary_dim if hasattr(self.child, 'attention') and hasattr(self.child.attention, 'rotary_ndims'): return self.child.attention.rotary_ndims + if policy.__class__ is FALCONLayerPolicy: + return self.child.self_attention.head_dim return -1 def set_moe(self, moe=False): diff --git a/deepspeed/module_inject/containers/falcon.py b/deepspeed/module_inject/containers/falcon.py new file mode 100644 index 000000000000..d10aa1d74064 --- /dev/null +++ b/deepspeed/module_inject/containers/falcon.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .base import * +from .features.meta_tensor import MetaTensorContainer +from .features.hybrid_engine import HybridEngineContainer +from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference + +from ..policy import transformer_param_names +from ..policy import ( + TransformerPolicy, + maybe_copy, + maybe_get_lora, +) + + +class DS_FALCONContainer(MetaTensorContainer, BaseTransformerContainer, HybridEngineContainer): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # All model specific things should be defined here instead of the base class. + + def create_module(self, config=None): + _config = config if config is not None else self.ds_model_config + + _config.rotate_half = True + _config.rotate_every_two = False + _config.rotary_dim = self.hidden_size // self.num_attention_heads + _config.multi_query = True + _config.global_kv_sharing = (self.policy.num_kv == 1) and self.policy.config.multi_query + _config.num_kv = self.policy.num_kv + _config.mlp_after_attn = not self.policy.config.parallel_attn + self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group) + + self.module.config.rotate_half = True + self.module.config.rotate_every_two = False + + return self.module + + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + attention = self.policy.client_module.self_attention + + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.dense_h_to_4h, self.policy.client_module.mlp.dense_4h_to_h, + attention.query_key_value, attention.dense + ] + ] + + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): + ln_2 = 'ln_mlp' if hasattr(self.policy.client_module, 'ln_mlp') else ('post_attention_layernorm' if hasattr(self.policy.client_module, 'post_attention_layernorm') else None) + ln_1 = 'ln_attn' if hasattr(self.policy.client_module, 'ln_attn') else 'input_layernorm' + param_names = ( + 'self_attention.query_key_value.weight', \ + 'self_attention.dense.weight', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_4h_to_h.weight', \ + None if ln_2 is None else ln_2 + '.weight', \ + None if ln_2 is None else ln_2 + '.bias', \ + ln_1 + '.weight', \ + ln_1 + '.bias' + ) + for i in range(0, 2): + maybe_copy(module.attention, + sd, + weight_quantizer, + mp_replace, + transformer_param_names[i * 2], + prefix + param_names[i], + qkv=True, + megatron_v2=self.policy.is_megatron_v2, + split_qkv=self.policy.split_qkv) + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[4], + prefix + param_names[2]) + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[6], + prefix + param_names[3]) + if ln_2 is not None: + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], + prefix + param_names[4]) + if 
ln_2 is not None: + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[9], + prefix + param_names[5]) + for i in range(6, 8): + maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i + 4], prefix + param_names[i]) + del sd + + def attention_qkv_mp(self, mp_replace, reversed_dim=False): + self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, + self.qkvw, + int8=reversed_dim) + self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, + self.qkvb, + int8=reversed_dim) + + +class FALCONLayerPolicy(TransformerPolicy): + + def __init__(self, client_module, inference=True): + super().__init__(inference, split_qkv=False) + self.client_module = client_module + FALCONLayerPolicy.name = 'falcon' + FALCONLayerPolicy._orig_layer_class = None + if client_module is not None: + self.num_kv = self.client_module.self_attention.num_kv if hasattr(self.client_module.self_attention, 'num_kv') else \ + self.client_module.self_attention.num_kv_heads + self.config = self.client_module.config + + def get_hidden_heads(self): + heads = self.client_module.self_attention.num_heads + return self.client_module.self_attention.query_key_value.weight.shape[1], \ + heads, \ + self.client_module.ln_mlp.eps if hasattr(self.client_module, 'ln_mlp') else self.client_module.input_layernorm.eps, \ + DEFAULT_INTERMEDIATE_SIZE + + def attention(self, enable_training=False): + attention = self.client_module.self_attention + + return attention.query_key_value.weight, \ + None, \ + attention.dense.weight, \ + None + + def mlp(self, enable_training=False): + return self.client_module.mlp.dense_h_to_4h.weight, \ + None, \ + self.client_module.mlp.dense_4h_to_h.weight, \ + None + + def layernorm(self): + ln_2 = self.client_module.ln_mlp if hasattr(self.client_module, 'ln_mlp') else (self.client_module.post_attention_layernorm if hasattr(self.client_module, 'post_attention_layernorm') else None) + ln_1 = self.client_module.ln_attn if hasattr(self.client_module, 'ln_attn') else self.client_module.input_layernorm + #import pdb;pdb.set_trace() + return ln_2.weight if ln_2 is not None else None, \ + ln_2.bias if ln_2 is not None else None, \ + ln_1.weight, \ + ln_1.bias diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 298a5081e78e..c67711e9e2ca 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -225,7 +225,7 @@ def load_module_recursive(module, prefix='', level=0): child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps) setattr(module, name, child) elif child.__class__ is nn.Linear: - child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias) + child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias, dtype=child.weight.dtype) setattr(module, name, child) elif child.__class__ is OPTLearnedPositionalEmbedding: child = OPTEmbedding(weight_shape=ds_shape) diff --git a/deepspeed/module_inject/policy.py b/deepspeed/module_inject/policy.py index 41df2b85dc0c..c2d5d1bce9ee 100644 --- a/deepspeed/module_inject/policy.py +++ b/deepspeed/module_inject/policy.py @@ -168,6 +168,7 @@ def maybe_copy(module, tmp = transpose(tmp) dst = mp_replace.copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \ transpose(tmp)), int8=weight_quantizer.q_int8) + del tmp, sd setattr(module, dst_name, dst) @@ -191,6 +192,7 @@ def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, 
dst_name, src_names else: dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \ transpose(qkv_data)), int8=weight_quantizer.q_int8) + del q, k, v, sd setattr(module, dst_name, dst) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index f0fe81f28714..f6fa77426e18 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -40,7 +40,7 @@ def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0): self.mp_size = mp_size def merge_assert(self, dim1, dim2): - assert dim1 > dim2, \ + assert dim1 >= dim2, \ 'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\ for merging your checkpoints before replacing the transformer layer with\ inference-kernels' @@ -438,6 +438,10 @@ def update_mp_params(child): assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format( child.num_heads, mp_size) child.num_heads = child.num_heads // mp_size + if hasattr(child, 'num_kv'): + assert child.num_kv % mp_size == 0, "num_kv ({}) must be divisible by mp_size ({})".format( + child.num_kv, mp_size) + child.num_kv = child.num_kv // mp_size if hasattr(child, 'num_attention_heads'): assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format( child.num_attention_heads, mp_size) @@ -492,9 +496,17 @@ def _replace_module(r_module, prev_name='', prev_class_name=''): continue if len(child._buffers) != 0 and state_dict != None: load_buffer(child, state_dict, checking_key) - if child.__class__ in linear_policies: - setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name, - conv_linear_layer)) + if any(isinstance(child, lp) for lp in linear_policies): + if child.__class__ in linear_policies: + setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name, + conv_linear_layer)) + else: + key = None + for lp in linear_policies: + if isinstance(child, lp): + key = lp + assert key is not None + setattr(r_module, name, linear_policies[key](child, prev_name + '.' 
+ name, conv_linear_layer)) else: update_mp_params(child) _replace_module(child, name, class_name) @@ -522,6 +534,7 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): return new_module if checkpoint_dict != None and not config.replace_with_kernel_inject: + # AutoTP shard loading checkpoint = checkpoint_dict["checkpoints"] pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") @@ -557,7 +570,11 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") for i in range(len(checkpoint)): - sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')] + if checkpoint[i].endswith(".safetensors"): + from safetensors.torch import load_file + sd = [load_file(os.path.join(base_dir1, checkpoint[i]), device=f'cuda:{torch.distributed.get_rank()}')] + else: + sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')] load_model_with_checkpoint(replaced_module, sd, mp_replace, @@ -565,6 +582,13 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): ckpt_mp_size, quantizer, container=container_g) + for sd_ in sd: + lens = len(sd_.keys()) + for _ in range(lens): + data = sd_.popitem() + del data + del sd_ + sd = None pbar.update(1) else: num_checkpoints = len(ckpt_list) // ckpt_mp_size @@ -773,6 +797,8 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No policy.update({orig_layer_class: (replace_fn, plcy)}) elif plcy._orig_layer_class is not None: policy.update({plcy._orig_layer_class: (replace_fn, plcy)}) + elif hasattr(plcy, 'name'): + policy.update({plcy.name: (replace_fn, plcy)}) assert len(policy.items()) > 0,\ "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\ "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py" @@ -833,12 +859,18 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di OPTLearnedPositionalEmbedding = None load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm, OPTLearnedPositionalEmbedding] for name, child in model.named_children(): - if child.__class__ in policies: - replaced_module = policies[child.__class__][0](child, - policies[child.__class__][-1], - layer_id, - prefix=prefix + name, - state_dict=state_dict) + key = child.__class__ + if key in policies or (isinstance(model, nn.ModuleList) + and any(pname in str(key) for pname in policies if isinstance(pname, str))): + if not key in policies: + for pname in policies: + if isinstance(pname, str) and pname in str(key): + key = pname + replaced_module = policies[key][0](child, + policies[key][-1], + layer_id, + prefix=prefix + name, + state_dict=state_dict) setattr(model, name, replaced_module) if isinstance(model, PipelineModule): assert hasattr(model, 'forward_funcs'),\ @@ -847,7 +879,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di layer_id += 1 else: checking_key = prefix + name + '.' 
- if child.__class__ in load_layers and state_dict != None: + if key in load_layers and state_dict != None: if any(checking_key in item for item in state_dict): load( child, diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index c49b8f81c430..ce033484a7e9 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -16,11 +16,22 @@ from .containers import LLAMALayerPolicy from .containers import UNetPolicy from .containers import VAEPolicy +from .containers import FALCONLayerPolicy # transformer-based policies replace_policies = [ - HFBertLayerPolicy, HFGPTNEOLayerPolicy, GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy, - HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy, LLAMALayerPolicy + HFBertLayerPolicy, + HFGPTNEOLayerPolicy, + GPTNEOXLayerPolicy, + HFGPTJLayerPolicy, + MegatronLayerPolicy, + HFGPT2LayerPolicy, + BLOOMLayerPolicy, + HFOPTLayerPolicy, + HFCLIPLayerPolicy, + HFDistilBertLayerPolicy, + LLAMALayerPolicy, + FALCONLayerPolicy, ] # non-transformer-based policies diff --git a/deepspeed/module_inject/utils.py b/deepspeed/module_inject/utils.py index c442d24fd3b6..6481304f8738 100644 --- a/deepspeed/module_inject/utils.py +++ b/deepspeed/module_inject/utils.py @@ -18,6 +18,7 @@ def policy_to_ds_container(**kwargs): from .containers import MegatronLayerPolicy, DS_MegatronGPTContainer from .containers import HFDistilBertLayerPolicy, DS_DistilBERTContainer from .containers import LLAMALayerPolicy, DS_LLAMAContainer + from .containers import FALCONLayerPolicy, DS_FALCONContainer policy_to_container = { HFGPT2LayerPolicy: DS_GPT2Container, @@ -30,6 +31,7 @@ def policy_to_ds_container(**kwargs): MegatronLayerPolicy: DS_MegatronGPTContainer, HFDistilBertLayerPolicy: DS_DistilBERTContainer, LLAMALayerPolicy: DS_LLAMAContainer, + FALCONLayerPolicy: DS_FALCONContainer, } container = None diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py index 261523529d0b..064b55125ad5 100644 --- a/deepspeed/ops/transformer/inference/config.py +++ b/deepspeed/ops/transformer/inference/config.py @@ -77,7 +77,10 @@ def __init__(self, scale_attn_by_inverse_layer_idx=False, return_single_tuple=False, set_empty_params=False, - transposed_mode=False): + transposed_mode=False, + multi_query=False, + global_kv_sharing=False, + num_kv=-1): super(DeepSpeedInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, num_hidden_layers) @@ -109,6 +112,9 @@ def __init__(self, self.return_single_tuple = return_single_tuple self.set_empty_params = set_empty_params self.transposed_mode = transposed_mode + self.multi_query = multi_query + self.global_kv_sharing = global_kv_sharing + self.num_kv = num_kv @classmethod def from_dict(cls, json_object): diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py index 967f1d4b8d9d..50903b4dd2de 100644 --- a/deepspeed/ops/transformer/inference/ds_attention.py +++ b/deepspeed/ops/transformer/inference/ds_attention.py @@ -37,7 +37,14 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count self.attn_ow = None self.attn_ob = None else: - qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + if self.config.multi_query: + if self.config.global_kv_sharing: + qkv_size_per_partition = (self.config.hidden_size 
// self.config.mp_size) + 2 * (self.config.hidden_size // self.config.heads) + else: + qkv_size_per_partition = (self.config.hidden_size // self.config.heads) * ( + self.config.num_kv * 2 + self.config.heads) // self.config.mp_size + else: + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, qkv_size_per_partition, dtype=data_type, @@ -154,11 +161,13 @@ def forward(self, bias=self._attn_qkvb, gamma=norm_w, beta=norm_b) - + context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out, input_mask=input_mask, layer_past=layer_past, alibi=alibi) + #print(context_layer) + #exit() output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) inp_norm = qkv_out[-1] diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py index f4bb538dab37..63c3e94266ff 100644 --- a/deepspeed/ops/transformer/inference/ds_mlp.py +++ b/deepspeed/ops/transformer/inference/ds_mlp.py @@ -111,7 +111,6 @@ def forward(self, input, residual, residual_norm, bias): bias=self.inter_b, gamma=self.attn_nw, beta=self.attn_nb) - residual = self.residual_add_func(hidden_state=output, residual=residual, add_bias=bias is not None, @@ -119,6 +118,7 @@ def forward(self, input, residual, residual_norm, bias): attention_bias=bias if bias is not None else self.output_b, final_bias=self.output_b, residual_add=residual_add) + #print(f'mlp_out: {}') if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: dist.all_reduce(residual, group=self.mp_group) diff --git a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py index e3e372d60080..bdbd64bda052 100644 --- a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py @@ -12,6 +12,8 @@ from .base import BaseOp from deepspeed.utils.types import NormType +dummy_tensor = torch.Tensor([]) + class MLPGemmOp(BaseOp): @@ -68,10 +70,10 @@ def forward(self, output, residual_add = self.mlp_gemm_func( input, residual, - input_bias, + input_bias if input_bias is not None else dummy_tensor, weight_interm, weight_out, - bias, + bias if bias is not None else dummy_tensor, gamma, beta, self.config.epsilon, diff --git a/deepspeed/ops/transformer/inference/op_binding/residual_add.py b/deepspeed/ops/transformer/inference/op_binding/residual_add.py index 37f964e02849..7b2fc3b18ee3 100644 --- a/deepspeed/ops/transformer/inference/op_binding/residual_add.py +++ b/deepspeed/ops/transformer/inference/op_binding/residual_add.py @@ -9,6 +9,8 @@ from ..config import DeepSpeedInferenceConfig from .base import BaseOp +dummy_tensor = torch.Tensor([]) + class ResidualAddOp(BaseOp): @@ -38,16 +40,17 @@ def forward(self, final_bias: Optional[torch.Tensor] = None): if self.residual_add_func != None: - if final_bias is None: + if final_bias is None and self.config.mlp_after_attn: residual = self._vector_add(residual, hidden_state, 1.0 / self.config.mp_size) else: if not self.config.pre_layer_norm and residual_add is not None: # only use residual add if its set and we are not pre layer norm residual = residual_add - self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias, - self.config.mp_size, self.config.mlp_after_attn, add_bias, - self.config.pre_layer_norm) + self.residual_add_func(hidden_state, residual, attention_output, + attention_bias if attention_bias is 
not None else dummy_tensor, + final_bias if final_bias is not None else dummy_tensor, self.config.mp_size, + self.config.mlp_after_attn, add_bias, self.config.pre_layer_norm) else: # fallback if os.environ.get('DS_KI_FALLBACK') == 'True' and self.config.mlp_after_attn: diff --git a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py index 3cc75860a752..3e453dfe5b84 100644 --- a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py +++ b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py @@ -22,6 +22,7 @@ def __init__(self, config: DeepSpeedInferenceConfig): self.softmax_context_func = self.inference_module.softmax_context_fp32 except AttributeError: self.softmax_context_func = self.softmax_context_fallback + self.num_kv_per_partition = self.config.num_kv // self.config.mp_size def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, roteate_every_two, heads, norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id, @@ -41,6 +42,7 @@ def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: output = self.softmax_context_func(query_key_value, attn_mask, self.config.rotary_dim, self.config.rotate_half, self.config.rotate_every_two, heads, norm_factor, self.config.triangular_masking, self.config.local_attention, - self.config.window_size, no_masking, layer_id, num_layers, alibi) + self.config.window_size, no_masking, layer_id, num_layers, alibi, + self.config.multi_query, self.num_kv_per_partition) return output diff --git a/test_falcon.py b/test_falcon.py new file mode 100644 index 000000000000..4c922ba8dd33 --- /dev/null +++ b/test_falcon.py @@ -0,0 +1,59 @@ +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import transformers +import torch +import deepspeed +import time +from deepspeed.accelerator import get_accelerator +import json +import io +import os +from pathlib import Path +from argparse import ArgumentParser + +parser = ArgumentParser() +parser.add_argument("--save_mp_sharded_ckpt", required=False, action='store_true') +parser.add_argument("--local_rank", type=int, default=-1) +parser.add_argument("--model-name", type=str, default='falcon-40b') +parser.add_argument("--ckpt-root", type=str, default='falcon-40b') +args = parser.parse_args() +repo_root = args.ckpt_root +model = "tiiuae/"+args.model_name +#AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True) +if args.save_mp_sharded_ckpt: + checkpoints_json = "checkpoints.json" + with io.open(checkpoints_json, "w", encoding="utf-8") as f: + file_list = [str(entry).split('/')[-1] for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()] + if len(file_list) == 0: + file_list = [str(entry).split('/')[-1] for entry in Path(repo_root).rglob("*.safetensors") if entry.is_file()] + data = {"type": "ds_model", "checkpoints": file_list, "version": 1.0} + json.dump(data, f) +else: + checkpoints_json = "/tmp/" + args.model_name + "/ds_inference_config.json" +tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) +config = AutoConfig.from_pretrained(model, trust_remote_code=True) + +with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + +model = deepspeed.init_inference(model, + mp_size=int(os.getenv("WORLD_SIZE", "1")), + replace_with_kernel_inject=True, + base_dir=repo_root, + dtype=torch.bfloat16, + 
checkpoint=checkpoints_json,
+#                                 save_mp_checkpoint_path='/tmp/'+args.model_name if args.save_mp_sharded_ckpt else None
+                                 )
+
+input_prompt = [
+    "Deep learning involves the use of neural networks"
+]
+input_tokens = tokenizer.batch_encode_plus(input_prompt, return_tensors="pt",)
+token_num = input_tokens['input_ids'].size(-1)
+for t in input_tokens:
+    if torch.is_tensor(input_tokens[t]):
+        input_tokens[t] = input_tokens[t].to(get_accelerator().current_device_name())
+input_tokens.pop('token_type_ids')
+sequences = model.generate(**input_tokens, min_length=100, max_length=100, do_sample=True)
+
+if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+    print(f"Result: {tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]}")