remove redundant cast
liqiangxl committed Feb 7, 2025
1 parent b6e1530 commit 1d25565
Showing 1 changed file with 20 additions and 24 deletions.
tests/cpp/test_persistent_buffer.cpp (20 additions, 24 deletions)
@@ -20,7 +20,7 @@
 namespace nvfuser {
 
 using testing::Contains;
-using testing::UnorderedElementsAre;
+using testing::ElementsAre;
 using PersistentBufferTest = NVFuserTest;
 
 TEST_F(PersistentBufferTest, FusionPersistentBufferCalculation1_CUDA) {
@@ -1472,18 +1472,23 @@ TEST_P(LayerNormSharedMemoryTest, FusionLayerNormSharedMemoryBuffer_CUDA) {
   constexpr int64_t dim0 = 2048;
   std::vector<int64_t> input_shape{dim0, hidden_size};
   std::vector<int64_t> norm_shape{hidden_size};
-  auto input_half = makeContigTensor(2, dtype);
-  auto weight_half = makeContigTensor(1, dtype);
-  auto bias_half = makeContigTensor(1, dtype);
-  fusion.addInput(input_half);
-  fusion.addInput(weight_half);
-  fusion.addInput(bias_half);
-  auto input = castOp(DataType::Float, input_half);
-  auto weight = castOp(DataType::Float, weight_half);
-  auto bias = castOp(DataType::Float, bias_half);
+
+  auto input = makeContigTensor(2, dtype);
+  auto weight = makeContigTensor(1, dtype);
+  auto bias = makeContigTensor(1, dtype);
+  fusion.addInput(input);
+  fusion.addInput(weight);
+  fusion.addInput(bias);
+  if (dtype == DataType::Half) {
+    input = castOp(DataType::Float, input);
+    weight = castOp(DataType::Float, weight);
+    bias = castOp(DataType::Float, bias);
+  }
   auto result = layer_norm(input, norm_shape, weight, bias, eps_ptr);
-  auto result_output = castOp(dtype, result.output);
-  fusion.addOutput(result_output);
+  if (dtype == DataType::Half) {
+    result.output = castOp(DataType::Half, result.output);
+  }
+  fusion.addOutput(result.output);
   fusion.addOutput(result.mean);
   fusion.addOutput(result.invstd);
Expand Down Expand Up @@ -1534,18 +1539,9 @@ TEST_P(LayerNormSharedMemoryTest, FusionLayerNormSharedMemoryBuffer_CUDA) {
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
auto runtime = executor_cache.getMostRecentKernelRuntime();
if (has_enough_regs_smem) {
// For dtype float, no op scheduler is also used.
if (dtype == DataType::Float) {
EXPECT_THAT(
runtime->fusionSegments()->groups(),
UnorderedElementsAre(
HeuristicIs(SchedulerType::NoOp),
HeuristicIs(SchedulerType::InnerPersistent)));
} else {
EXPECT_THAT(
runtime->fusionSegments()->groups(),
UnorderedElementsAre(HeuristicIs(SchedulerType::InnerPersistent)));
}
EXPECT_THAT(
runtime->fusionSegments()->groups(),
ElementsAre(HeuristicIs(SchedulerType::InnerPersistent)));
Fusion* scheduled_fusion = runtime->executors()
.back()
->as<KernelExecutor>()
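Note: the change above replaces an unconditional Half-to-Float up-cast of the inputs (and the unconditional cast of the output back to dtype) with casts guarded by dtype == DataType::Half, so the DataType::Float parameterization no longer records a redundant Float-to-Float cast in the fusion. Below is a minimal standalone sketch of that guarded-cast pattern, assuming the nvfuser test helpers already used in this file (NVFuserTest, makeContigTensor, castOp); the include paths, test name, and the add op are illustrative assumptions, not part of this commit.

#include <gtest/gtest.h>

#include <fusion.h>
#include <ops/all_ops.h>
#include <tests/cpp/utils.h>

namespace nvfuser {

TEST_F(NVFuserTest, GuardedCastSketch_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  DataType dtype = DataType::Half; // the real test parameterizes this
  auto tv0 = makeContigTensor(2, dtype);
  fusion.addInput(tv0);
  if (dtype == DataType::Half) {
    // Up-cast only when the input is not already Float; doing this
    // unconditionally would insert a no-op Float-to-Float cast.
    tv0 = castOp(DataType::Float, tv0);
  }
  auto tv1 = add(tv0, tv0); // stand-in for the layer_norm math above
  if (dtype == DataType::Half) {
    // Mirror the guard on the output side.
    tv1 = castOp(DataType::Half, tv1);
  }
  fusion.addOutput(tv1);
}

} // namespace nvfuser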

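The matcher change (UnorderedElementsAre to ElementsAre) also tightens the assertion: once the Float path no longer produces a separate NoOp segment, exactly one InnerPersistent segment is expected, and ElementsAre checks order as well as membership. A standalone gtest/gmock sketch of the difference, independent of nvfuser (test name and data are illustrative):

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <vector>

using testing::ElementsAre;
using testing::UnorderedElementsAre;

TEST(MatcherSketch, ElementsAreChecksOrder) {
  std::vector<int> v{1, 2, 3};
  EXPECT_THAT(v, ElementsAre(1, 2, 3));          // passes: same elements, same order
  EXPECT_THAT(v, UnorderedElementsAre(3, 1, 2)); // passes: same elements, any order
  // EXPECT_THAT(v, ElementsAre(3, 1, 2));       // would fail: order differs
}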