diff --git a/csrc/fused_dense_lib/fused_dense.cpp b/csrc/fused_dense_lib/fused_dense.cpp
index eacd153b1..924941002 100644
--- a/csrc/fused_dense_lib/fused_dense.cpp
+++ b/csrc/fused_dense_lib/fused_dense.cpp
@@ -28,19 +28,19 @@
 }
 
 template <typename T>
-int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
 
 template <typename T>
-int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
 
 template <typename T>
-int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
 
 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
-  int batch_size = input.size(0);
-  int in_features = input.size(1);
-  int out_features = d_output.size(1);
+  int64_t batch_size = input.size(0);
+  int64_t in_features = input.size(1);
+  int64_t out_features = d_output.size(1);
 
   TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
   TORCH_CHECK(input.dtype() == d_output.dtype());
 
@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
     d_bias = at::empty({out_features}, opts);
 #endif
   }
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
 
   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
     auto result = linear_bias_wgrad_cuda<scalar_t>(
@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
         batch_size,
         out_features,
         d_weight.data_ptr<scalar_t>(),
-        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
-        (void*) (lt_workspace.data_ptr()));
+        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
     TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
   });
 
   return {d_weight, d_bias};
 }
 
-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
-                                            c10::optional<at::Tensor> bias_,
-                                            bool save_gelu_in, int heuristic) {
+std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
+                                           c10::optional<at::Tensor> bias_,
+                                           bool is_gelu, bool save_pre_act, int heuristic) {
 
-  int batch_size = input.size(0);
-  int in_features = input.size(1);
-  int out_features = weight.size(0);
+  int64_t batch_size = input.size(0);
+  int64_t in_features = input.size(1);
+  int64_t out_features = weight.size(0);
 
   TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
   TORCH_CHECK(input.dtype() == weight.dtype());
@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
   // create output/workspace tensor
   auto opts = input.options();
   auto output = at::empty({batch_size, out_features}, opts);
-  at::Tensor gelu_in;
-  if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
+  at::Tensor pre_act;
+  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+  if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
+                                          is_gelu ? opts : opts.dtype(torch::kUInt8)); }
 
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-    auto result = linear_gelu_forward_cuda<scalar_t>(
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
+    auto result = linear_act_forward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
        in_features,
        batch_size,
        out_features,
+       is_gelu,
        heuristic,
        output.data_ptr<scalar_t>(),
-       save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
-       (void*) (lt_workspace.data_ptr()));
-    TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
+       save_pre_act ? pre_act.data_ptr() : nullptr);
+    TORCH_CHECK(result == 0, "linear_act_forward failed.");
   });
 
   std::vector<at::Tensor> result = {output};
-  if (save_gelu_in) { result.push_back(gelu_in); };
+  if (save_pre_act) { result.push_back(pre_act); };
   return result;
 }
 
-std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
-  at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
+std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
+  at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
 ) {
 
-  int batch_size = d_output.size(0);
-  int out_features = d_output.size(1);
-  int in_features = weight.size(1);
+  int64_t batch_size = d_output.size(0);
+  int64_t out_features = d_output.size(1);
+  int64_t in_features = weight.size(1);
 
   TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
   TORCH_CHECK(weight.dtype() == d_output.dtype());
-  TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+  TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
   TORCH_CHECK(weight.is_cuda());
   TORCH_CHECK(d_output.is_cuda());
-  TORCH_CHECK(gelu_in.is_cuda());
+  TORCH_CHECK(pre_act.is_cuda());
   TORCH_CHECK(weight.is_contiguous());
   TORCH_CHECK(d_output.is_contiguous());
-  TORCH_CHECK(gelu_in.is_contiguous());
+  TORCH_CHECK(pre_act.is_contiguous());
   CHECK_SHAPE(weight, out_features, in_features);
   CHECK_SHAPE(d_output, batch_size, out_features);
-  CHECK_SHAPE(gelu_in, batch_size, in_features);
+  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+  CHECK_SHAPE(pre_act, batch_size, is_gelu ?
in_features : in_features / 8); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing @@ -170,22 +168,20 @@ std::vector bias_gelu_linear_dgrad_bgrad( auto opts = weight.options(); auto d_bias = at::empty({in_features}, opts); auto d_input = at::empty({batch_size, in_features}, opts); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, opts); - DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] { - auto result = bias_gelu_linear_dgrad_bgrad_cuda( + DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] { + auto result = bias_act_linear_dgrad_bgrad_cuda( weight.data_ptr(), d_output.data_ptr(), - gelu_in.data_ptr(), + pre_act.data_ptr(), in_features, batch_size, out_features, + is_gelu, heuristic, d_input.data_ptr(), - d_bias.data_ptr(), - (void*) (lt_workspace.data_ptr())); - TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed."); + d_bias.data_ptr()); + TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed."); }); return {d_input, d_bias}; @@ -193,6 +189,6 @@ std::vector bias_gelu_linear_dgrad_bgrad( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad"); - m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward"); - m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad"); + m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward"); + m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad"); } diff --git a/csrc/fused_dense_lib/fused_dense_cuda.cu b/csrc/fused_dense_lib/fused_dense_cuda.cu index 8243d0bd1..7b6f39200 100644 --- a/csrc/fused_dense_lib/fused_dense_cuda.cu +++ b/csrc/fused_dense_lib/fused_dense_cuda.cu @@ -12,7 +12,6 @@ #include #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 -// includes cublaslt #include #endif @@ -21,17 +20,17 @@ cublasStatus_t gemm_bias( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, + int64_t m, + int64_t n, + int64_t k, const float* alpha, - at::Half* A, - int lda, - at::Half* B, - int ldb, + const at::Half* A, + int64_t lda, + const at::Half* B, + int64_t ldb, const float* beta, at::Half* C, - int ldc) { + int64_t ldc) { return cublasGemmEx( handle, transa, @@ -59,17 +58,17 @@ cublasStatus_t gemm_bias( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, + int64_t m, + int64_t n, + int64_t k, const float* alpha, - at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, + const at::BFloat16* A, + int64_t lda, + const at::BFloat16* B, + int64_t ldb, const float* beta, at::BFloat16* C, - int ldc) { + int64_t ldc) { return cublasGemmEx( handle, transa, @@ -94,28 +93,40 @@ cublasStatus_t gemm_bias( #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, +template +int gemm_bias_act_lt( cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const Dtype* A, + int64_t lda, + const Dtype* B, + int64_t ldb, + const Dtype* bias, + Dtype* C, int64_t ldc, - void *workspace, - size_t 
workspaceSize, - cudaStream_t stream, - int heuristic, - const void* gelu_in, - const void* bias) { - bool save_gelu_in = gelu_in != nullptr; + void* pre_act, + bool is_gelu, + int heuristic + ) { + static_assert(std::is_same::value || std::is_same::value, + "gemm_bias_act_lt only supports fp16 and bf16"); + bool save_pre_act = pre_act != nullptr; + float beta = 0.0; + cudaDataType_t abcType = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; + + cublasLtHandle_t ltHandle = + reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. + size_t workspaceSize = 1024 * 1024; + void* workspace = at::empty( + {static_cast(workspaceSize)}, + at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr(); + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasLtMatmulDescOpaque_t operationDesc = {}; @@ -125,7 +136,11 @@ int gemm_bias_gelu_lt( int returnedResults = 0; constexpr int requestedAlgoCount = 5; cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; - cublasLtEpilogue_t epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU; + // constexpr int requestedAlgoCount = 1; + // cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue = is_gelu + ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU) + : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU); // Create operation descriptor; see cublasLtMatmulDescAttributes_t // for details about defaults; here we just set the transforms for @@ -137,8 +152,8 @@ int gemm_bias_gelu_lt( status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - if (save_gelu_in) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); + if (save_pre_act) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act)); status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); } @@ -147,9 +162,13 @@ int gemm_bias_gelu_lt( if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } - epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS; + epilogue = is_gelu + ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS) + : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS); } else { - epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU; + epilogue = is_gelu + ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU) + : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU); } status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); @@ -159,12 +178,12 @@ int gemm_bias_gelu_lt( // Create matrix descriptors. Not setting any extra attributes. status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + &Bdesc, abcType, transb == CUBLAS_OP_N ? 
k : n, transb == CUBLAS_OP_N ? n : k, ldb); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Create preference handle; In general, extra attributes can be @@ -184,6 +203,7 @@ int gemm_bias_gelu_lt( // run them one by one until something works. status = cublasLtMatmulAlgoGetHeuristic( ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); + // ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; if (returnedResults == 0) { @@ -192,12 +212,12 @@ int gemm_bias_gelu_lt( } status = cublasLtMatmul(ltHandle, &operationDesc, - alpha, + &alpha, A, &Adesc, B, &Bdesc, - beta, + &beta, C, &Cdesc, C, @@ -208,7 +228,7 @@ int gemm_bias_gelu_lt( // NULL, workspace, workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); CLEANUP: // Descriptors are no longer needed as all GPU work was already @@ -216,147 +236,71 @@ CLEANUP: return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; } -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, +template int gemm_bias_act_lt( cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, - const float *beta, /* host pointer */ - at::BFloat16* C, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const at::Half* A, + int64_t lda, + const at::Half* B, + int64_t ldb, + const at::Half* bias, + at::Half* C, int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - int heuristic, - const void* gelu_in, - const void* bias) { - bool save_gelu_in = gelu_in != nullptr; - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - constexpr int requestedAlgoCount = 5; - cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; - cublasLtEpilogue_t epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (save_gelu_in) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - } - - if (bias != nullptr) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = save_gelu_in ? 
CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS; - } else { - epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - // &heuristicResult.algo, - // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos - &heuristicResult[heuristic].algo, - // NULL, - workspace, - workspaceSize, - stream); + void* pre_act, + bool is_gelu, + int heuristic); -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; -} +template int gemm_bias_act_lt( + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const at::BFloat16* A, + int64_t lda, + const at::BFloat16* B, + int64_t ldb, + const at::BFloat16* bias, + at::BFloat16* C, + int64_t ldc, + void* pre_act, + bool is_gelu, + int heuristic); +template int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void* bgrad) { + int64_t m, + int64_t n, + int64_t k, + float alpha, + const Dtype* A, + int64_t lda, + const Dtype* B, + int64_t ldb, + Dtype* C, + int64_t ldc, + Dtype* bgrad) { + static_assert(std::is_same::value || std::is_same::value, + "gemm_bgradb_lt only supports fp16 and bf16"); + float beta = 0.0; + cudaDataType_t abcType = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; + + cublasLtHandle_t ltHandle = + reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. + size_t workspaceSize = 1024 * 1024; + void* workspace = at::empty( + {static_cast(workspaceSize)}, + at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr(); + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasLtMatmulDescOpaque_t operationDesc = {}; @@ -392,12 +336,12 @@ int gemm_bgradb_lt( // Create matrix descriptors. Not setting any extra attributes. status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Create preference handle; In general, extra attributes can be @@ -425,12 +369,12 @@ int gemm_bgradb_lt( } status = cublasLtMatmul(ltHandle, &operationDesc, - alpha, + &alpha, A, &Adesc, B, &Bdesc, - beta, + &beta, C, &Cdesc, C, @@ -439,7 +383,7 @@ int gemm_bgradb_lt( NULL, workspace, workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); CLEANUP: // Descriptors are no longer needed as all GPU work was already @@ -447,136 +391,69 @@ CLEANUP: return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; } -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, + +template int gemm_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, - const float *beta, /* host pointer */ - at::BFloat16* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void* bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (bgrad != nullptr) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BGRADB; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. 
- status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); + int64_t m, + int64_t n, + int64_t k, + float alpha, + const at::Half* A, + int64_t lda, + const at::Half* B, + int64_t ldb, + at::Half* C, + int64_t ldc, + at::Half* bgrad); -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} +template int gemm_bgradb_lt( + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const at::BFloat16* A, + int64_t lda, + const at::BFloat16* B, + int64_t ldb, + at::BFloat16* C, + int64_t ldc, + at::BFloat16* bgrad); -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, +template +int gemm_dact_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const Dtype* A, + int64_t lda, + const Dtype* B, + int64_t ldb, + const void* pre_act, + Dtype* C, int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - int heuristic, - const void *gelu_in, - const void *bgrad) { + Dtype* bgrad, + bool is_gelu, + int heuristic) { + static_assert(std::is_same::value || std::is_same::value, + "gemm_dact_bgradb_lt only supports fp16 and bf16"); + float beta = 0.0; + cudaDataType_t abcType = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; + + cublasLtHandle_t ltHandle = + reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. + size_t workspaceSize = 1024 * 1024; + void* workspace = at::empty( + {static_cast(workspaceSize)}, + at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr(); + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasLtMatmulDescOpaque_t operationDesc = {}; @@ -586,7 +463,7 @@ int gemm_dgelu_bgradb_lt( int returnedResults = 0; constexpr int requestedAlgoCount = 5; cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; + cublasLtEpilogue_t epilogue = is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD; // Create operation descriptor; see cublasLtMatmulDescAttributes_t // for details about defaults; here we just set the transforms for @@ -602,7 +479,7 @@ int gemm_dgelu_bgradb_lt( if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } @@ -615,12 +492,12 @@ int gemm_dgelu_bgradb_lt( // Create matrix descriptors. Not setting any extra attributes. status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? 
m : k, transa == CUBLAS_OP_N ? k : m, lda); + &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Create preference handle; In general, extra attributes can be @@ -648,12 +525,12 @@ int gemm_dgelu_bgradb_lt( } status = cublasLtMatmul(ltHandle, &operationDesc, - alpha, + &alpha, A, &Adesc, B, &Bdesc, - beta, + &beta, C, &Cdesc, C, @@ -663,7 +540,7 @@ int gemm_dgelu_bgradb_lt( // NULL, workspace, workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); CLEANUP: // Descriptors are no longer needed as all GPU work was already @@ -671,155 +548,69 @@ CLEANUP: return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; } -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, +template int gemm_dact_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, - const float *beta, /* host pointer */ - at::BFloat16* C, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const at::Half* A, + int64_t lda, + const at::Half* B, + int64_t ldb, + const void* pre_act, + at::Half* C, int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - int heuristic, - const void *gelu_in, - const void *bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + at::Half* bgrad, + bool is_gelu, + int heuristic); - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - constexpr int requestedAlgoCount = 5; - cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. 
- status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - &heuristicResult[heuristic].algo, - // NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; -} +template int gemm_dact_bgradb_lt( + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const at::BFloat16* A, + int64_t lda, + const at::BFloat16* B, + int64_t ldb, + const void* pre_act, + at::BFloat16* C, + int64_t ldc, + at::BFloat16* bgrad, + bool is_gelu, + int heuristic); #endif template -int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); +int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) { const float alpha = 1.0; const float beta_zero = 0.0; int status = 1; #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 status = gemm_bgradb_lt( - (cublasLtHandle_t)handle, + // (cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size, - &alpha, /* host pointer */ + alpha, input, in_features, d_output, out_features, - &beta_zero, /* host pointer */ d_weight, in_features, - lt_workspace, - 1 << 22, - stream, - static_cast(d_bias)); + d_bias); #endif if (status != 0){ + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); status = gemm_bias( handle, CUBLAS_OP_N, @@ -835,42 +626,48 @@ int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_siz &beta_zero, d_weight, in_features); + // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error + // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341 + // at::cuda::blas::gemm( + // 'N', + // 'T', + // in_features, + // out_features, + // batch_size, + // alpha, + // input, + // in_features, + // d_output, + // out_features, + // beta_zero, + // d_weight, + // in_features); } return status; } template -int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. 
- cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; +int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) { int status = 1; #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - status = gemm_bias_gelu_lt( - (cublasLtHandle_t)handle, + status = gemm_bias_act_lt( CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, in_features, - &alpha, /* host pointer */ + /*alpha=*/1.0, weight, in_features, input, in_features, - &beta_zero, /* host pointer */ + bias, output, out_features, - lt_workspace, - 1 << 22, - stream, - heuristic, - static_cast(gelu_in), - static_cast(bias)); + pre_act, + is_gelu, + heuristic); return status; #else return 1; @@ -878,46 +675,37 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int } template -int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); +int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) { const float alpha = 1.0; - const float beta_zero = 0.0; int status = 1; #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - status = gemm_dgelu_bgradb_lt( - (cublasLtHandle_t)handle, + status = gemm_dact_bgradb_lt( CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, out_features, - &alpha, /* host pointer */ + alpha, weight, in_features, d_output, out_features, - &beta_zero, /* host pointer */ + pre_act, d_input, in_features, - lt_workspace, - 1 << 22, - stream, - heuristic, - static_cast(gelu_in), - static_cast(d_bias)); + d_bias, + is_gelu, + heuristic); #endif return status; } -template int linear_bias_wgrad_cuda(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace) ; -template int linear_bias_wgrad_cuda(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace) ; +template int linear_bias_wgrad_cuda(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias); +template int linear_bias_wgrad_cuda(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias); -template int linear_gelu_forward_cuda(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace) ; -template int linear_gelu_forward_cuda(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace) ; +template int linear_act_forward_cuda(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool 
is_gelu, int heuristic, at::Half *output, void *pre_act); +template int linear_act_forward_cuda(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act); -template int bias_gelu_linear_dgrad_bgrad_cuda(at::Half *weight, at::Half *d_output, at::Half *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace); -template int bias_gelu_linear_dgrad_bgrad_cuda(at::BFloat16 *weight, at::BFloat16 *d_output, at::BFloat16 *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace); \ No newline at end of file +template int bias_act_linear_dgrad_bgrad_cuda(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias); +template int bias_act_linear_dgrad_bgrad_cuda(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias); \ No newline at end of file diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index 360eedc27..861fb4a86 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -23,7 +23,7 @@ from einops import rearrange from flash_attn.modules.mha import MHA -from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense +from flash_attn.modules.mlp import Mlp, FusedMLP from flash_attn.modules.block import Block from flash_attn.modules.embedding import BertEmbeddings from flash_attn.bert_padding import unpad_input, pad_input @@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False): def create_mlp_cls(config, layer_idx=None, return_residual=False): inner_dim = config.intermediate_size - fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False) - if fused_dense_gelu_dense: - assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only ' + fused_mlp = getattr(config, 'fused_mlp', False) + if fused_mlp: + assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only ' 'supports approximate gelu') - if not fused_dense_gelu_dense: + if not fused_mlp: approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none' mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=partial(F.gelu, approximate=approximate), return_residual=return_residual) else: - if FusedDenseGeluDense is None: + if FusedMLP is None: raise ImportError('fused_dense is not installed') mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0) # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer if isinstance(mlp_checkpoint_lvl, Sequence): assert layer_idx is not None mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] - mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim, + mlp_cls = partial(FusedMLP, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual) return mlp_cls diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index e2f9e6bb1..e37261c69 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -17,7 +17,7 @@ from einops import rearrange from flash_attn.modules.mha import MHA, 
ParallelMHA -from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense +from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP from flash_attn.modules.block import Block from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.utils.distributed import sync_shared_params, all_gather_raw @@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): factory_kwargs = {'device': device, 'dtype': dtype} inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size - fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False) - if fused_dense_gelu_dense: - assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only ' - 'supports approximate gelu') + fused_mlp = getattr(config, 'fused_mlp', False) + if fused_mlp: + assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu'] fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False) if fused_dense_sqrelu_dense: assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only ' 'supports approximate activation_function sqrelu') - assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense) + assert not (fused_dense_sqrelu_dense and fused_mlp) if process_group is not None: - assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense' - if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense: + assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP' + if not fused_mlp and not fused_dense_sqrelu_dense: if config.activation_function == 'relu': activation = partial(F.relu, inplace=True) else: - approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none' + approximate = ('tanh' if config.activation_function + in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none') activation=partial(F.gelu, approximate=approximate) mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs) else: @@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp if isinstance(mlp_checkpoint_lvl, Sequence): assert layer_idx is not None mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] - if fused_dense_gelu_dense: - if FusedDenseGeluDense is None: + if fused_mlp: + if FusedMLP is None: raise ImportError('fused_dense is not installed') - mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense + activation = ('gelu_approx' if config.activation_function + in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu') + mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP parallel_kwargs = ({'process_group': process_group, 'sequence_parallel': getattr(config, 'sequence_parallel', True)} if process_group is not None else {}) - mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl, + mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation, + checkpoint_lvl=mlp_checkpoint_lvl, **parallel_kwargs, **factory_kwargs) elif fused_dense_sqrelu_dense: assert FusedDenseSqreluDense is not None @@ -210,7 +213,8 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No factory_kwargs = {'device': device, 'dtype': dtype} self.process_group = process_group self.sequence_parallel = 
getattr(config, 'sequence_parallel', True) - assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu'] + assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', + 'relu', 'sqrelu'] pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) diff --git a/flash_attn/models/vit.py b/flash_attn/models/vit.py index 646c5e715..312f5e495 100644 --- a/flash_attn/models/vit.py +++ b/flash_attn/models/vit.py @@ -20,7 +20,7 @@ from flash_attn.layers.patch_embed import PatchEmbed from flash_attn.modules.mha import MHA -from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense +from flash_attn.modules.mlp import Mlp, FusedMLP from flash_attn.modules.block import Block try: @@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_ return mixer_cls -def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense): +def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp): inner_dim = int(embed_dim * mlp_ratio) - if not fused_dense_gelu_dense: + if not fused_mlp: mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) else: - mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim) + mlp_cls = partial(FusedMLP, hidden_features=inner_dim) return mlp_cls def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc, - fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None, + fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False): mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc, cross_attn=(last_layer_subset and layer_idx == n_layer - 1)) - mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense) + mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp) # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate, @@ -92,7 +92,7 @@ def __init__( act_layer=None, use_flash_attn=False, fused_bias_fc=False, - fused_dense_gelu_dense=False, + fused_mlp=False, fused_dropout_add_ln=False, ): """ @@ -164,7 +164,7 @@ def __init__( embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i], norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn, - fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense, + fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp, fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth, last_layer_subset=(global_pool == 'token') ) for i in range(depth)]) diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index d6927024e..763c7beee 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -121,7 +121,8 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, ) if mixer_kwargs is None: mixer_kwargs = {} - mixer_kwargs['mixer_subset'] = mixer_subset + if mixer_subset is not None: + mixer_kwargs['mixer_subset'] = mixer_subset hidden_states = self.mixer(hidden_states, **mixer_kwargs) if mixer_subset is not None: residual = residual[:, mixer_subset] diff --git a/flash_attn/modules/mlp.py 
b/flash_attn/modules/mlp.py index 3771cdee3..5240e3f76 100644 --- a/flash_attn/modules/mlp.py +++ b/flash_attn/modules/mlp.py @@ -5,9 +5,9 @@ import torch.nn.functional as F try: - from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense + from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP except ImportError: - FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None + FusedMLP, ParallelFusedMLP = None, None class Mlp(nn.Module): diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index b4e3e28e0..dfb506e5c 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -1,8 +1,9 @@ -# Copyright (c) 2022, Tri Dao. +# Copyright (c) 2023, Tri Dao. # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py # We make it work with pytorch amp and with bfloat16. # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py from typing import Optional +from functools import partial import torch import torch.nn as nn @@ -19,6 +20,11 @@ from flash_attn.utils.distributed import reduce_scatter, all_reduce +@torch.jit.script +def relu_bwd(g, x): + return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) + + class FusedDenseFunc(torch.autograd.Function): @staticmethod @@ -185,12 +191,13 @@ def forward(self, x): return reduce_fn(out, self.process_group) -class FusedDenseGeluDenseFunc(torch.autograd.Function): +class FusedMLPFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False, - checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True): + def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True, + return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None, + sequence_parallel=True): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul. 
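(Reference note, not part of the patch: the fused path above computes an ordinary linear -> activation -> linear MLP; the only change this diff makes to the saved state is that the generic pre-activation tensor `pre_act` replaces `gelu_in`, and for ReLU only a 1-bit-per-element sign mask is kept. A minimal unfused PyTorch sketch of the same forward semantics, with the function name chosen here for illustration, mirrors the fallback branch of fused_mlp_func:

import torch
import torch.nn.functional as F

def mlp_reference(x, weight1, bias1, weight2, bias2, activation='gelu_approx'):
    # pre_act is the tensor the fused kernel optionally saves for the backward pass
    pre_act = F.linear(x, weight1, bias1)
    if activation == 'gelu_approx':
        out1 = F.gelu(pre_act, approximate='tanh')  # matches gelu_new / gelu_fast
    else:  # 'relu': the backward only needs the sign of pre_act, hence the bit-mask
        out1 = F.relu(pre_act)
    return F.linear(out1, weight2, bias2)
)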
@@ -198,10 +205,11 @@ def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_re checkpoint_lvl: 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd + 1: recompute gelu_out / relu_out in the bwd + 2: recompute pre_act and gelu_out / relu_out in the bwd """ assert -1 <= heuristic <= 4 + assert activation in ['gelu_approx', 'relu'] if not save_pre_act: checkpoint_lvl = 2 assert checkpoint_lvl in [0, 1, 2] @@ -209,6 +217,7 @@ def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_re ctx.process_group = process_group ctx.sequence_parallel = sequence_parallel ctx.checkpoint_lvl = checkpoint_lvl + ctx.activation = activation ctx.heuristic = heuristic if torch.is_autocast_enabled(): @@ -237,23 +246,27 @@ def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_re if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: raise RuntimeError('fused_dense only supports matrix dims <= 2M') if heuristic == -1: - gelu_in = F.linear(total_x, weight1, bias1) - output1 = F.gelu(gelu_in, approximate='tanh') + pre_act = F.linear(total_x, weight1, bias1) + activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else F.relu) + output1 = activation_fn(pre_act) # This is before adding bias1 - # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) + # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) # with torch.jit.fuser('fuser2'): - # output1 = bias_gelu(gelu_in, bias1) + # output1 = bias_gelu(pre_act, bias1) else: - output1, *rest = fused_dense_cuda.linear_gelu_forward( - total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic + is_gelu = activation == 'gelu_approx' + output1, *rest = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic ) if save_pre_act: - gelu_in = rest[0] + pre_act = rest[0] output2 = F.linear(output1, weight2, bias2) - if checkpoint_lvl == 0: - ctx.save_for_backward(x, weight1, weight2, gelu_in, output1) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'): + # For RELU the pre_act is very small (just a bit-mask) so we just save it + ctx.save_for_backward(x, weight1, weight2, pre_act, output1) elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, weight2, gelu_in) + ctx.save_for_backward(x, weight1, weight2, pre_act) elif checkpoint_lvl == 2: ctx.save_for_backward(x, weight1, weight2, bias1) output2 = output2.reshape(*batch_shape, output2.shape[-1]) @@ -264,6 +277,9 @@ def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_re def backward(ctx, grad_output, *args): grad_output = grad_output.contiguous() checkpoint_lvl = ctx.checkpoint_lvl + activation = ctx.activation + activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else F.relu) if ctx.return_residual: grad_input, = args grad_input = grad_input.contiguous() @@ -277,27 +293,27 @@ def backward(ctx, grad_output, *args): if checkpoint_lvl in [0, 1]: if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - if checkpoint_lvl == 0: - gelu_in, output1 = rest + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'): + pre_act, output1 = rest elif checkpoint_lvl == 1: - gelu_in, = rest - output1 = F.gelu(gelu_in, approximate='tanh') + pre_act, = rest + output1 = activation_fn(pre_act) elif checkpoint_lvl == 2: bias1, = 
rest if process_group is not None and sequence_parallel: total_x, _ = all_gather_raw(x, process_group) if ctx.heuristic == -1: - gelu_in = F.linear(total_x, weight1, bias1) - output1 = F.gelu(gelu_in, approximate='tanh') + pre_act = F.linear(total_x, weight1, bias1) + output1 = activation_fn(pre_act) else: - output1, gelu_in = fused_dense_cuda.linear_gelu_forward( - total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True, - ctx.heuristic + output1, pre_act = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, + activation == 'gelu_approx', True, ctx.heuristic ) grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) output1 = output1.reshape(batch_dim, output1.shape[-1]) - gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1]) + pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) if ctx.needs_input_grad[3]: grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( output1, grad_output, ctx.needs_input_grad[4] @@ -306,24 +322,25 @@ def backward(ctx, grad_output, *args): grad_weight2 = None grad_bias2 = grad_output if ctx.needs_input_grad[4] else None if ctx.heuristic == -1: - # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in) + # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) grad_output1 = F.linear(grad_output, weight2.t()) with torch.jit.fuser('fuser2'): - grad_gelu = gelu_bwd(grad_output1, gelu_in) + activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd + grad_pre_act = activation_grad_fn(grad_output1, pre_act) else: - # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't - # just compute gelu grad - grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad( - weight2, grad_output, gelu_in, ctx.heuristic + # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't + # just compute gelu/relu grad + grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( + weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic ) if not ctx.needs_input_grad[2]: grad_bias1 = None if ctx.needs_input_grad[0]: if not ctx.return_residual: - grad_input = F.linear(grad_gelu, weight1.t()) + grad_input = F.linear(grad_pre_act, weight1.t()) else: grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_gelu, weight1) + grad_pre_act, weight1) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw @@ -335,55 +352,60 @@ def backward(ctx, grad_output, *args): if process_group is not None and sequence_parallel: handle_x.wait() grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu, + total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act, ctx.needs_input_grad[2] ) else: grad_weight1 = None - grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None + grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None else: if ctx.needs_input_grad[1]: if process_group is not None and sequence_parallel: handle_x.wait() - grad_weight1 = F.linear(grad_gelu.t(), + grad_weight1 = F.linear(grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()) else: grad_weight1 = None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, - None, None, None, None, None, None) + None, None, None, None, None, 
None, None) -def fused_dense_gelu_dense_func( +def fused_mlp_func( x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None, - bias2: Optional[Tensor] = None, + bias2: Optional[Tensor] = None, activation: str = 'gelu_approx', save_pre_act: bool = True, return_residual: bool = False, checkpoint_lvl: int = 0, heuristic: int = 0, process_group: Optional[ProcessGroup] = None, sequence_parallel: bool = True ): + assert activation in ['gelu_approx', 'relu'] dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] or (x.dtype == torch.float32 and torch.is_autocast_enabled())) + # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) + dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0) if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda) - and (bias2 is None or bias2.is_cuda) and dtype_eligible): - return FusedDenseGeluDenseFunc.apply( - x, weight1, bias1, weight2, bias2, save_pre_act, return_residual, + and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible): + return FusedMLPFunc.apply( + x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group, sequence_parallel ) else: assert process_group is None - gelu_in = F.linear(x, weight1, bias1) - output1 = F.gelu(gelu_in, approximate='tanh') + pre_act = F.linear(x, weight1, bias1) + activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else partial(F.relu, inplace=True)) + output1 = activation_fn(pre_act) output2 = F.linear(output1, weight2, bias2) return output2 if not return_residual else (output2, x) -class FusedDenseGeluDense(nn.Module): +class FusedMLP(nn.Module): def __init__(self, in_features, hidden_features, out_features=None, bias1=True, - bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0, - device=None, dtype=None): + bias2=True, activation='gelu_approx', return_residual=False, + checkpoint_lvl=0, heuristic='auto', device=None, dtype=None): """ If process_group is not None, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. @@ -392,21 +414,24 @@ def __init__(self, in_features, hidden_features, out_features=None, bias1=True, checkpoint_lvl (increasing lvl means slower but more memory saving): 0: no recomputation in the bwd 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd heuristic: -1: don't fuse gemm + gelu (separate kernel) 0..4: use this heuristic for the algo section in the fused gemm + gelu - For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16. + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. return_residual: whether to return the input x along with the output. This is for performance reason: for post-norm architecture, returning the input allows us to fuse the backward of nn.Linear with the residual connection. 
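The backward above keeps `pre_act` around even at checkpoint_lvl 1 when the activation is ReLU, because cuBLASLt stores the ReLU pre-activation as a bit-mask (1 bit per element, packed into uint8), so there is almost nothing to gain from recomputing it. A rough back-of-the-envelope sketch of the saving, assuming fp16/bf16 activations and illustrative sizes (not taken from the diff):

```python
# Rough estimate of what the forward saves for the backward pass, assuming fp16/bf16
# activations and the uint8 bit-mask layout used for ReLU (1 bit per element,
# i.e. hidden_features / 8 uint8 columns).
def saved_pre_act_bytes(batch_dim: int, hidden_features: int, activation: str) -> int:
    if activation == 'gelu_approx':
        return batch_dim * hidden_features * 2   # fp16/bf16 pre-activation
    elif activation == 'relu':
        return batch_dim * hidden_features // 8  # bit-mask, 1 bit per element
    raise ValueError(activation)

# Example: batch 8 x seqlen 2048 tokens, hidden dim 4 * 1024
print(saved_pre_act_bytes(8 * 2048, 4096, 'gelu_approx'))  # 134217728 bytes (~128 MiB)
print(saved_pre_act_bytes(8 * 2048, 4096, 'relu'))         # 8388608 bytes (~8 MiB)
```

This is why the forward treats checkpoint_lvl 1 like checkpoint_lvl 0 for ReLU: the saved mask is already tiny.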
""" assert checkpoint_lvl in [0, 1, 2] + assert activation in ['gelu_approx', 'relu'] factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() if out_features is None: out_features = in_features + self.activation = activation self.return_residual = return_residual self.checkpoint_lvl = checkpoint_lvl self.heuristic = heuristic @@ -414,11 +439,20 @@ def __init__(self, in_features, hidden_features, out_features=None, bias1=True, self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x, process_group=None): - out = fused_dense_gelu_dense_func( + dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() + if self.heuristic == 'auto': + if self.activation == 'gelu_approx': + cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) + heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, - save_pre_act=self.training, return_residual=self.return_residual, - checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic, - process_group=process_group + activation=self.activation, save_pre_act=self.training, + return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, process_group=process_group ) if self.return_residual: out, x = out @@ -427,11 +461,12 @@ def forward(self, x, process_group=None): return out if not self.return_residual else (out, x) -class ParallelFusedDenseGeluDense(nn.Module): +class ParallelFusedMLP(nn.Module): - def __init__(self, in_features, hidden_features, out_features=None, + def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx', process_group: ProcessGroup = None, bias1=True, bias2=True, - sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None): + sequence_parallel=True, checkpoint_lvl=0, heuristic='auto', + device=None, dtype=None): """ process_group is required. We're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. @@ -440,19 +475,22 @@ def __init__(self, in_features, hidden_features, out_features=None, checkpoint_lvl (increasing lvl means slower but more memory saving): 0: no recomputation in the bwd 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd heuristic: -1: don't fuse gemm + gelu (separate kernel) 0..4: use this heuristic for the algo section in the fused gemm + gelu - For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16. + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
""" assert checkpoint_lvl in [0, 1, 2] + assert activation in ['gelu_approx', 'relu'] assert process_group is not None factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() if out_features is None: out_features = in_features + self.activation = activation self.process_group = process_group self.sequence_parallel = sequence_parallel self.checkpoint_lvl = checkpoint_lvl @@ -463,10 +501,19 @@ def __init__(self, in_features, hidden_features, out_features=None, bias=bias2, **factory_kwargs) def forward(self, x): - out = fused_dense_gelu_dense_func( + dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() + if self.heuristic == 'auto': + if self.activation == 'gelu_approx': + cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) + heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, - save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl, - heuristic=self.heuristic, + activation=self.activation, save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic, process_group=self.process_group, sequence_parallel=self.sequence_parallel ) diff --git a/tests/models/test_bert.py b/tests/models/test_bert.py index 61525be96..2470ff307 100644 --- a/tests/models/test_bert.py +++ b/tests/models/test_bert.py @@ -95,13 +95,13 @@ def test_bert_optimized(model_name): """ dtype = torch.float16 config = BertConfig.from_pretrained(model_name) - # Our implementation of fused_dense_gelu_dense assumes the activation is + # Our implementation of fused_mlp assumes the activation is # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast". - # If you just want "gelu", disable fused_dense_gelu_dense. + # If you just want "gelu", disable fused_mlp. config.hidden_act = "gelu_new" config.use_flash_attn = True config.fused_bias_fc = True - config.fused_dense_gelu_dense = True + config.fused_mlp = True config.fused_dropout_add_ln = True model = BertForPreTraining.from_pretrained(model_name, config) @@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs """ dtype = torch.float16 config = BertConfig.from_pretrained(model_name) - # Our implementation of fused_dense_gelu_dense assumes the activation is + # Our implementation of fused_mlp assumes the activation is # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast". - # If you just want "gelu", disable fused_dense_gelu_dense. + # If you just want "gelu", disable fused_mlp. 
config.hidden_act = "gelu_new" config.use_flash_attn = True config.fused_bias_fc = True - config.fused_dense_gelu_dense = True + config.fused_mlp = True config.fused_dropout_add_ln = True config.dense_seq_output = True config.last_layer_subset = last_layer_subset diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index 2a73c271f..98a996091 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name): vocab_size_og = config.vocab_size config.use_flash_attn = True config.fused_bias_fc = True - config.fused_dense_gelu_dense = True + config.fused_mlp = True config.fused_dropout_add_ln = True config.residual_in_fp32 = True config.pad_vocab_size_multiple = 8 diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py index 4ec3082e3..2d1fdc795 100644 --- a/tests/models/test_gpt_generation.py +++ b/tests/models/test_gpt_generation.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize('fused_ft_kernel', [False, True]) # @pytest.mark.parametrize('fused_ft_kernel', [True]) @pytest.mark.parametrize('optimized', [False, True]) -# @pytest.mark.parametrize('optimized', [True]) +# @pytest.mark.parametrize('optimized', [False]) @pytest.mark.parametrize('rotary', [False, True]) # @pytest.mark.parametrize('rotary', [False]) @pytest.mark.parametrize('model_name', ["gpt2"]) @@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): if rotary: config.n_positions = 0 config.rotary_emb_dim = 64 + config.residual_in_fp32 = True if optimized: config.use_flash_attn = True config.fused_bias_fc = True - config.fused_dense_gelu_dense = True + config.fused_mlp = True config.fused_dropout_add_ln = True # if not rotary, we load the weight from HF but ignore the position embeddings. 
@@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): fused_ft_kernel=fused_ft_kernel, return_dict_in_generate=True, output_scores=True, timing=True) print(out.sequences) + print(tokenizer.batch_decode(out.sequences.tolist())) if fused_ft_kernel: out_cg = model.generate(input_ids=input_ids, max_length=max_length, fused_ft_kernel=fused_ft_kernel, cg=True, @@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') - - assert torch.all(out.sequences == sequences) - assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), - rtol=rtol, atol=atol) - if not rotary: - assert torch.all(out.sequences == out_ref.sequences) - assert torch.all(out.sequences == out_hf.sequences) - - assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() - - -# Run test with: -# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel" - -# @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) -@pytest.mark.parametrize('world_size', [2]) -# @pytest.mark.parametrize('fused_ft_kernel', [False, True]) -@pytest.mark.parametrize('fused_ft_kernel', [True]) -# @pytest.mark.parametrize('rotary', [False, True]) -@pytest.mark.parametrize('rotary', [False]) -@pytest.mark.parametrize('model_name', ["gpt2"]) -def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size): - """Check that our implementation of GPT2 generation matches the HF implementation: - the scores in fp16 should be around the same as the HF scores in fp16, when compared to - the HF scores in fp32. - """ - dtype = torch.float16 - rtol, atol = 3e-3, 3e-1 - config = GPT2Config.from_pretrained(model_name) - if rotary: - config.n_positions = 0 - config.rotary_emb_dim = 64 - config.use_flash_attn = True - config.fused_bias_fc = True - config.fused_dense_gelu_dense = True - config.fused_dropout_add_ln = True - config.pad_vocab_size_multiple = 8 * world_size - config.sequence_parallel = False # Need to set this to False for generation - - os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl', init_method='env://') - device = f'cuda:{torch.distributed.get_rank()}' - assert world_size <= torch.distributed.get_world_size() - # Need this, otherwise when we capture the graph the process for GPU 1 would run on both - # GPU0 and GPU1 and things would hang - torch.cuda.set_device(device) - - from apex.transformer import parallel_state - parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) - rank = parallel_state.get_tensor_model_parallel_rank() - process_group = parallel_state.get_tensor_model_parallel_group() - - # if not rotary, we load the weight from HF but ignore the position embeddings. - # The model would be nonsense but it doesn't matter for the test. 
- model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device, - dtype=dtype, process_group=process_group, - world_size=world_size, rank=rank) - model.eval() - - if not rotary: - model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) - model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype) - model_ref.eval() - model_hf.eval() - - torch.manual_seed(0) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - input_ids = tokenizer("Hello, my dog is cute and ", - return_tensors="pt").input_ids.to(device=device) - max_length = 30 - # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda') - # max_length = input_ids.shape[1] + 40 - - # Slow generation for reference - sequences = [] - scores = [] - cur_input_ids = input_ids - with torch.inference_mode(): - logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) - logits = rearrange(logits, '(n b) d -> b (n d)', - b=input_ids.shape[0])[..., :config.vocab_size] - scores.append(logits) - sequences.append(scores[-1].argmax(dim=-1)) - for _ in range(input_ids.shape[1] + 1, max_length): - cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) - logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) - logits = rearrange(logits, '(n b) d -> b (n d)', - b=input_ids.shape[0])[..., :config.vocab_size] - scores.append(logits) - sequences.append(scores[-1].argmax(dim=-1)) - sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) - scores = tuple(scores) - print(sequences) - - out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, - vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, - return_dict_in_generate=True, output_scores=True, timing=True) - print(out.sequences) - if fused_ft_kernel: - out_cg = model.generate( - input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, - vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True, - return_dict_in_generate=True, output_scores=True, timing=True) - print(out_cg.sequences) - - if not rotary: - out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, - return_dict_in_generate=True, output_scores=True) - out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, - return_dict_in_generate=True, output_scores=True) - - print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') - print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') - print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') - print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + print(tokenizer.batch_decode(out_ref.sequences.tolist())) assert torch.all(out.sequences == sequences) assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), diff --git a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py new file mode 100644 index 000000000..5817a91df --- /dev/null +++ b/tests/models/test_gpt_generation_parallel.py @@ -0,0 +1,131 @@ +# Run test with: +# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel" +import os +import re + +import torch +import pytest + +from einops import rearrange + +from 
transformers import GPT2Config, GPT2Tokenizer +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF + +from flash_attn.models.gpt import GPTLMHeadModel +from flash_attn.models.gpt import remap_state_dict_gpt2 +from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.distributed import all_gather_raw + + +# @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) +@pytest.mark.parametrize('world_size', [2]) +# @pytest.mark.parametrize('fused_ft_kernel', [False, True]) +@pytest.mark.parametrize('fused_ft_kernel', [True]) +# @pytest.mark.parametrize('rotary', [False, True]) +@pytest.mark.parametrize('rotary', [False]) +@pytest.mark.parametrize('model_name', ["gpt2"]) +def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size): + """Check that our implementation of GPT2 generation matches the HF implementation: + the scores in fp16 should be around the same as the HF scores in fp16, when compared to + the HF scores in fp32. + """ + dtype = torch.float16 + rtol, atol = 3e-3, 3e-1 + config = GPT2Config.from_pretrained(model_name) + if rotary: + config.n_positions = 0 + config.rotary_emb_dim = 64 + config.residual_in_fp32 = True + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = True + config.pad_vocab_size_multiple = 8 * world_size + config.sequence_parallel = False # Need to set this to False for generation + + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl', init_method='env://') + device = f'cuda:{torch.distributed.get_rank()}' + assert world_size <= torch.distributed.get_world_size() + # Need this, otherwise when we capture the graph the process for GPU 1 would run on both + # GPU0 and GPU1 and things would hang + torch.cuda.set_device(device) + + from apex.transformer import parallel_state + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + process_group = parallel_state.get_tensor_model_parallel_group() + + # if not rotary, we load the weight from HF but ignore the position embeddings. + # The model would be nonsense but it doesn't matter for the test. 
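The generation loop further down in this new test gathers each rank's logits shard and reassembles the full vocabulary with einops. A toy illustration of that rearrange, assuming `all_gather_raw` concatenates the per-rank shards along the first (batch) dimension as the `'(n b) d -> b (n d)'` pattern implies:

```python
import torch
from einops import rearrange

# Each of n tensor-parallel ranks holds a (batch, vocab/n) logits shard; gathering
# along the first dimension gives (n * batch, vocab/n), which the rearrange stitches
# back into (batch, vocab).
n, batch, vocab = 2, 1, 8
shards = [torch.arange(vocab // n) + r * (vocab // n) for r in range(n)]    # per-rank shards
gathered = torch.cat([s.expand(batch, -1) for s in shards], dim=0)          # (n * batch, vocab/n)
full = rearrange(gathered, '(n b) d -> b (n d)', b=batch)                   # (batch, vocab)
assert full.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7]]
```

The `[..., :config.vocab_size]` slice in the test then drops the columns introduced by `pad_vocab_size_multiple`.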
+ model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device, + dtype=dtype, process_group=process_group, + world_size=world_size, rank=rank) + model.eval() + + if not rotary: + model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) + model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype) + model_ref.eval() + model_hf.eval() + + torch.manual_seed(0) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + input_ids = tokenizer("Hello, my dog is cute and ", + return_tensors="pt").input_ids.to(device=device) + max_length = 30 + # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda') + # max_length = input_ids.shape[1] + 40 + + # Slow generation for reference + sequences = [] + scores = [] + cur_input_ids = input_ids + with torch.inference_mode(): + logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) + logits = rearrange(logits, '(n b) d -> b (n d)', + b=input_ids.shape[0])[..., :config.vocab_size] + scores.append(logits) + sequences.append(scores[-1].argmax(dim=-1)) + for _ in range(input_ids.shape[1] + 1, max_length): + cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) + logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) + logits = rearrange(logits, '(n b) d -> b (n d)', + b=input_ids.shape[0])[..., :config.vocab_size] + scores.append(logits) + sequences.append(scores[-1].argmax(dim=-1)) + sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) + scores = tuple(scores) + print(sequences) + + out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, + vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, + return_dict_in_generate=True, output_scores=True, timing=True) + print(out.sequences) + if fused_ft_kernel: + out_cg = model.generate( + input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, + vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True, + return_dict_in_generate=True, output_scores=True, timing=True) + print(out_cg.sequences) + + if not rotary: + out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + + print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + + assert torch.all(out.sequences == sequences) + assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), + rtol=rtol, atol=atol) + if not rotary: + assert torch.all(out.sequences == out_ref.sequences) + assert torch.all(out.sequences == out_hf.sequences) + + assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() diff --git a/tests/models/test_gpt_parallel.py b/tests/models/test_gpt_parallel.py index dd4451b8b..d222d8885 100644 --- a/tests/models/test_gpt_parallel.py +++ 
b/tests/models/test_gpt_parallel.py @@ -1,6 +1,8 @@ # Run test with: # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py +import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): n_positions=seqlen if has_pos_emb else 0, vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, scale_attn_by_inverse_layer_idx=True, use_flash_attn=True, - fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True, + fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True, + residual_in_fp32=True, rotary_emb_fraction=0.0 if has_pos_emb else 0.5, pad_vocab_size_multiple=8 * world_size, sequence_parallel=sequence_parallel) + config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size) model_pt = GPTLMHeadModel(config, device=device) def init_layer_norm(module): @@ -131,9 +135,9 @@ def init_layer_norm(module): grad_dict['transformer.embeddings.position_embeddings.weight'], rtol=rtol, atol=atol ) - assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'], + assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'], rtol=rtol, atol=atol) - assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'], + assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'], rtol=rtol, atol=atol) for i in range(num_layers): assert torch.allclose( diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py index 745272107..b1cb959d8 100644 --- a/tests/models/test_vit.py +++ b/tests/models/test_vit.py @@ -8,11 +8,11 @@ from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224 -@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True]) -# @pytest.mark.parametrize('fused_dense_gelu_dense', [False]) +@pytest.mark.parametrize('fused_mlp', [False, True]) +# @pytest.mark.parametrize('fused_mlp', [False]) @pytest.mark.parametrize('optimized', [False, True]) # @pytest.mark.parametrize('optimized', [True]) -def test_vit(optimized, fused_dense_gelu_dense): +def test_vit(optimized, fused_mlp): """Check that our implementation of ViT matches the timm's implementation: the output of our forward pass in fp16 should be around the same as timm' forward pass in fp16, when compared to timm's forward pass in fp32. 
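The new `vocab_size` adjustment in test_gpt_parallel.py rounds the GPT-2 vocabulary up so the padded embedding is a multiple of 8 and splits evenly across tensor-parallel ranks. A worked example, taking world_size = 2 as an illustration:

```python
import math

vocab_size, world_size = 50257, 2
multiple = 8 * world_size                              # 16
padded = math.ceil(vocab_size / multiple) * multiple   # 50272
assert padded % world_size == 0 and padded % 8 == 0
print(padded, padded // world_size)                    # 50272 total, 25136 per rank
```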
@@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense): kwargs = {} if optimized: kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) - kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense + kwargs['fused_mlp'] = fused_mlp model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype) model_ref = vit_base_patch16_224(pretrained=True).to(device=device) @@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense): print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') - assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item() + rtol = 2 if not fused_mlp else 4 + assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item() diff --git a/tests/modules/test_block_parallel.py b/tests/modules/test_block_parallel.py index 34a9cffe5..f1409a432 100644 --- a/tests/modules/test_block_parallel.py +++ b/tests/modules/test_block_parallel.py @@ -15,7 +15,7 @@ from apex.transformer import tensor_parallel from flash_attn.modules.mha import MHA, ParallelMHA -from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense +from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP from flash_attn.modules.block import Block from flash_attn.utils.distributed import allreduce_sequence_parallel_grad @@ -27,7 +27,7 @@ @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize('sequence_parallel', [True, False]) -# @pytest.mark.parametrize('sequence_parallel', [False]) +# @pytest.mark.parametrize('sequence_parallel', [True]) @pytest.mark.parametrize('dim', [1024]) def test_block_parallel(dim, sequence_parallel, world_size, dtype): head_dim = 64 @@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, device=device, dtype=dtype) - mlp_cls_pt = partial(FusedDenseGeluDense, hidden_features=4 * dim, - device=device, dtype=dtype) + mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype) norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype) model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True) with torch.no_grad(): @@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): process_group=parallel_state.get_tensor_model_parallel_group(), rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, sequence_parallel=sequence_parallel, device=device, dtype=dtype) - mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim, + mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim, process_group=parallel_state.get_tensor_model_parallel_group(), sequence_parallel=sequence_parallel, device=device, dtype=dtype) model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True, @@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad, - rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small + rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small ) assert torch.allclose( residual.grad, diff --git a/tests/ops/test_fused_dense.py b/tests/ops/test_fused_dense.py index 
5cbf250cb..45c3b4142 100644 --- a/tests/ops/test_fused_dense.py +++ b/tests/ops/test_fused_dense.py @@ -1,4 +1,5 @@ import math +from functools import partial import torch import torch.nn.functional as F @@ -6,7 +7,7 @@ from einops import rearrange -from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense +from flash_attn.ops.fused_dense import FusedDense, FusedMLP @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @@ -60,15 +61,25 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('heuristic', [0, -1]) +# @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize('heuristic', ['auto', -1]) +# @pytest.mark.parametrize('heuristic', ['auto']) @pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2]) +# @pytest.mark.parametrize('checkpoint_lvl', [1]) @pytest.mark.parametrize('return_residual', [False, True]) +# @pytest.mark.parametrize('return_residual', [False]) @pytest.mark.parametrize('has_bias2', [True, False]) @pytest.mark.parametrize('has_bias1', [True, False]) +# @pytest.mark.parametrize('has_bias2', [True]) +# @pytest.mark.parametrize('has_bias1', [True]) +@pytest.mark.parametrize('activation', ['gelu_approx', 'relu']) +# @pytest.mark.parametrize('activation', ['relu']) @pytest.mark.parametrize('out_features', [1024, 4096]) @pytest.mark.parametrize('in_features', [1024, 4096]) -def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual, - checkpoint_lvl, heuristic, dtype): +# @pytest.mark.parametrize('out_features', [4096]) +# @pytest.mark.parametrize('in_features', [1024]) +def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual, + checkpoint_lvl, heuristic, dtype): device = 'cuda' rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) # set seed @@ -82,10 +93,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, dtype=dtype) model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, dtype=dtype) - model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1, - bias2=has_bias2, return_residual=return_residual, - checkpoint_lvl=checkpoint_lvl, heuristic=heuristic, - device=device, dtype=dtype) + model = FusedMLP(in_features, out_features, in_features, activation=activation, + bias1=has_bias1, bias2=has_bias2, return_residual=return_residual, + checkpoint_lvl=checkpoint_lvl, heuristic=heuristic, + device=device, dtype=dtype) with torch.no_grad(): model.fc1.weight.copy_(model_pt_fc1.weight) if has_bias1: @@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, model.fc2.weight.copy_(model_pt_fc2.weight) if has_bias2: model.fc2.bias.copy_(model_pt_fc2.bias) - out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else partial(F.relu, inplace=True)) + out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt))) if not return_residual: out = model(x) else: @@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, g = torch.randn_like(out) / 32 out_pt.backward(g) out.backward(g) + # The error for relu is higher still + if activation == 'relu': + atol = 1e-1 if dtype == torch.bfloat16 else 5e-2 assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) # The 
error for d_weight and d_bias is quite a bit higher assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10) diff --git a/tests/ops/test_fused_dense_parallel.py b/tests/ops/test_fused_dense_parallel.py index 9feff055d..a1830d3b5 100644 --- a/tests/ops/test_fused_dense_parallel.py +++ b/tests/ops/test_fused_dense_parallel.py @@ -10,8 +10,8 @@ from apex.transformer import parallel_state from apex.transformer import tensor_parallel -from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense -from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense +from flash_attn.ops.fused_dense import FusedDense, FusedMLP +from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedMLP is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 @@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle # @pytest.mark.parametrize('has_bias2', [True]) @pytest.mark.parametrize('out_features', [4096]) @pytest.mark.parametrize('in_features', [1024]) -def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel, - world_size, dtype): +def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype): assert out_features % world_size == 0 rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) if not torch.distributed.is_initialized(): @@ -137,11 +136,11 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p dtype=dtype) partition_out_features = out_features // world_size partition_in_features = in_features // world_size - model = ParallelFusedDenseGeluDense(in_features, out_features, in_features, - process_group=parallel_state.get_tensor_model_parallel_group(), - bias2=has_bias2 and rank == 0, - sequence_parallel=sequence_parallel, - device=device, dtype=dtype) + model = ParallelFusedMLP(in_features, out_features, in_features, + process_group=parallel_state.get_tensor_model_parallel_group(), + bias2=has_bias2 and rank == 0, + sequence_parallel=sequence_parallel, + device=device, dtype=dtype) with torch.no_grad(): model.fc1.weight.copy_( diff --git a/training/README.md b/training/README.md index b0ee8a33f..6bb429e12 100644 --- a/training/README.md +++ b/training/README.md @@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim, n_layer=n_layer, n_head=nheads, scale_attn_by_inverse_layer_idx=True, rotary_emb_fraction=rotary_emb_fraction, - use_flash_attn=True, fused_dense_gelu_dense=True, + use_flash_attn=True, fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True, pad_vocab_size_multiple=8) model = GPTLMHeadModel(config) diff --git a/training/configs/experiment/owt/gpt2s-flash.yaml b/training/configs/experiment/owt/gpt2s-flash.yaml index 0bcd4021f..f2dc6956d 100644 --- a/training/configs/experiment/owt/gpt2s-flash.yaml +++ b/training/configs/experiment/owt/gpt2s-flash.yaml @@ -7,9 +7,10 @@ defaults: model: config: # n_positions is already set to ${datamodule.max_length} + residual_in_fp32: True use_flash_attn: True fused_bias_fc: True - fused_dense_gelu_dense: True + fused_mlp: True fused_dropout_add_ln: True pad_vocab_size_multiple: 8 diff --git a/training/configs/experiment/pile/gpt3s-flash.yaml b/training/configs/experiment/pile/gpt3s-flash.yaml index 3def2a8e8..45302fd4d 100644 --- a/training/configs/experiment/pile/gpt3s-flash.yaml +++ b/training/configs/experiment/pile/gpt3s-flash.yaml @@ -7,9 +7,10 @@ defaults: model: config: 
     # n_positions is already set to ${datamodule.max_length}
+    residual_in_fp32: True
     use_flash_attn: True
     fused_dropout_add_ln: True
-    fused_dense_gelu_dense: True
+    fused_mlp: True
     fused_bias_fc: True
     pad_vocab_size_multiple: 8
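Finally, a minimal usage sketch of the renamed module, assuming a CUDA build with the fused_dense_lib extension installed. The shapes are arbitrary but keep the input's last dimension divisible by 128, which the fused ReLU path requires when the pre-activation bit-mask is saved for the backward:

```python
import torch
from flash_attn.ops.fused_dense import FusedMLP   # renamed from FusedDenseGeluDense

mlp = FusedMLP(1024, 4 * 1024, activation='relu', heuristic='auto',
               device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
out = mlp(x)                         # shape (8, 512, 1024); out_features defaults to in_features
out.backward(torch.randn_like(out))
```

`ParallelFusedMLP` takes the same arguments plus `process_group` (required) and `sequence_parallel`.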