[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
tridao committed Jan 18, 2023
1 parent 780e8ee commit 88173a1
Showing 20 changed files with 654 additions and 779 deletions.
86 changes: 41 additions & 45 deletions csrc/fused_dense_lib/fused_dense.cpp
@@ -28,19 +28,19 @@
 }
 
 template <typename T>
-int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
 
 template <typename T>
-int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
 
 template <typename T>
-int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
 
 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
 
-  int batch_size = input.size(0);
-  int in_features = input.size(1);
-  int out_features = d_output.size(1);
+  int64_t batch_size = input.size(0);
+  int64_t in_features = input.size(1);
+  int64_t out_features = d_output.size(1);
 
   TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
   TORCH_CHECK(input.dtype() == d_output.dtype());
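
In the hunk above, the CUDA-side declarations switch their size arguments from int to int64_t and drop the caller-provided cuBLASLt workspace pointer from the public signatures. The 64-bit extents presumably guard against index-arithmetic overflow when batch_size is the flattened (batch * seqlen) token count; the standalone sketch below (not part of the commit, sizes are made up) shows how easily a 32-bit product overflows at these shapes.

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes: batch_size here is (batch * seqlen), as in callers of this extension.
    int64_t batch_size = 32768;      // e.g. 8 sequences of 4096 tokens
    int64_t out_features = 131072;   // a wide MLP layer
    // 32768 * 131072 = 2^32, which does not fit in a 32-bit int.
    int64_t elems = batch_size * out_features;
    std::printf("elements = %lld, INT32_MAX = %d\n", (long long)elems, INT32_MAX);
    return 0;
}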
@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
     d_bias = at::empty({out_features}, opts);
 #endif
   }
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
 
   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
     auto result = linear_bias_wgrad_cuda<scalar_t>(
@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
         batch_size,
         out_features,
         d_weight.data_ptr<scalar_t>(),
-        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
+        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
     TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
   });
 
   return {d_weight, d_bias};
 }
 
-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
-                                            c10::optional<at::Tensor> bias_,
-                                            bool save_gelu_in, int heuristic) {
+std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
+                                           c10::optional<at::Tensor> bias_,
+                                           bool is_gelu, bool save_pre_act, int heuristic) {
 
-  int batch_size = input.size(0);
-  int in_features = input.size(1);
-  int out_features = weight.size(0);
+  int64_t batch_size = input.size(0);
+  int64_t in_features = input.size(1);
+  int64_t out_features = weight.size(0);
 
   TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
   TORCH_CHECK(input.dtype() == weight.dtype());
@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
   // create output/workspace tensor
   auto opts = input.options();
   auto output = at::empty({batch_size, out_features}, opts);
-  at::Tensor gelu_in;
-  if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
+  at::Tensor pre_act;
+  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+  if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
+                                          is_gelu ? opts : opts.dtype(torch::kUInt8)); }
 
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-    auto result = linear_gelu_forward_cuda<scalar_t>(
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
+    auto result = linear_act_forward_cuda<scalar_t>(
         input.data_ptr<scalar_t>(),
         weight.data_ptr<scalar_t>(),
         bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
         in_features,
         batch_size,
         out_features,
+        is_gelu,
         heuristic,
         output.data_ptr<scalar_t>(),
-        save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
+        save_pre_act ? pre_act.data_ptr() : nullptr);
+    TORCH_CHECK(result == 0, "linear_act_forward failed.");
   });
 
   std::vector<at::Tensor> result = {output};
-  if (save_gelu_in) { result.push_back(gelu_in); };
+  if (save_pre_act) { result.push_back(pre_act); };
   return result;
 }
 
-std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
-  at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
+std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
+  at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
 ) {
 
-  int batch_size = d_output.size(0);
-  int out_features = d_output.size(1);
-  int in_features = weight.size(1);
+  int64_t batch_size = d_output.size(0);
+  int64_t out_features = d_output.size(1);
+  int64_t in_features = weight.size(1);
 
   TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
   TORCH_CHECK(weight.dtype() == d_output.dtype());
-  TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+  TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
   TORCH_CHECK(weight.is_cuda());
   TORCH_CHECK(d_output.is_cuda());
-  TORCH_CHECK(gelu_in.is_cuda());
+  TORCH_CHECK(pre_act.is_cuda());
   TORCH_CHECK(weight.is_contiguous());
   TORCH_CHECK(d_output.is_contiguous());
-  TORCH_CHECK(gelu_in.is_contiguous());
+  TORCH_CHECK(pre_act.is_contiguous());
   CHECK_SHAPE(weight, out_features, in_features);
   CHECK_SHAPE(d_output, batch_size, out_features);
-  CHECK_SHAPE(gelu_in, batch_size, in_features);
+  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
+  CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);
 
   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
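
The hunk above replaces the saved gelu_in tensor with a generic pre_act buffer: for GELU the full pre-activation is kept in the input dtype, while for ReLU cuBLASLt emits a bit-mask with 1 bit per element, so the buffer shrinks to out_features / 8 uint8 bytes per row. A small libtorch sketch (not from the commit; example sizes are assumed and a CUDA device is required) showing the resulting memory difference for fp16:

#include <torch/torch.h>
#include <iostream>

int main() {
    int64_t batch_size = 4096, out_features = 8192;  // assumed example sizes
    auto opts = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);

    // GELU: full pre-activation saved in the input dtype (2 bytes per element for fp16).
    auto pre_act_gelu = torch::empty({batch_size, out_features}, opts);

    // ReLU: 1-bit-per-element mask packed into uint8, i.e. out_features / 8 bytes per row.
    auto pre_act_relu = torch::empty({batch_size, out_features / 8}, opts.dtype(torch::kUInt8));

    std::cout << "GELU pre_act: " << pre_act_gelu.numel() * pre_act_gelu.element_size() << " bytes\n";
    std::cout << "ReLU mask:    " << pre_act_relu.numel() * pre_act_relu.element_size() << " bytes (16x smaller)\n";
    return 0;
}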
@@ -170,29 +168,27 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
   auto opts = weight.options();
   auto d_bias = at::empty({in_features}, opts);
   auto d_input = at::empty({batch_size, in_features}, opts);
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
 
-  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
-    auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
+  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
+    auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
        weight.data_ptr<scalar_t>(),
        d_output.data_ptr<scalar_t>(),
-       gelu_in.data_ptr<scalar_t>(),
+       pre_act.data_ptr(),
        in_features,
        batch_size,
        out_features,
+       is_gelu,
        heuristic,
        d_input.data_ptr<scalar_t>(),
-       d_bias.data_ptr<scalar_t>(),
-       (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
+       d_bias.data_ptr<scalar_t>());
+    TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
   });
 
   return {d_input, d_bias};
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
-  m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
-  m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
+  m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
+  m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
 }
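
Because linear_act_forward is an ordinary C++ function over at::Tensor (the pybind11 module only exposes it to Python under the same name), it can also be called directly from C++ when linking against the built fused_dense_lib extension. A hedged caller sketch; the tensor sizes and the heuristic value 0 are placeholders, not values taken from this commit, and the binary must link the extension for the declaration below to resolve.

#include <torch/torch.h>
#include <vector>

// Declaration matching the wrapper defined above (provided by fused_dense_lib when linked in).
std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
                                           c10::optional<at::Tensor> bias_,
                                           bool is_gelu, bool save_pre_act, int heuristic);

int main() {
    auto opts = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
    auto input  = torch::randn({4096, 1024}, opts);  // (batch * seqlen, in_features)
    auto weight = torch::randn({4096, 1024}, opts);  // (out_features, in_features)
    auto bias   = torch::randn({4096}, opts);        // (out_features)

    // is_gelu = false selects the new ReLU path; save_pre_act keeps the bit-mask for backward.
    auto outs = linear_act_forward(input, weight, bias, /*is_gelu=*/false,
                                   /*save_pre_act=*/true, /*heuristic=*/0);
    auto output    = outs[0];  // (batch * seqlen, out_features)
    auto relu_mask = outs[1];  // (batch * seqlen, out_features / 8), uint8
    return 0;
}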
(The diffs for the remaining 19 changed files are not shown here.)
