diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index a08650d712f..97c46626109 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -20,7 +20,7 @@ namespace nvfuser { using testing::Contains; -using testing::UnorderedElementsAre; +using testing::ElementsAre; using PersistentBufferTest = NVFuserTest; TEST_F(PersistentBufferTest, FusionPersistentBufferCalculation1_CUDA) { @@ -1472,18 +1472,23 @@ TEST_P(LayerNormSharedMemoryTest, FusionLayerNormSharedMemoryBuffer_CUDA) { constexpr int64_t dim0 = 2048; std::vector input_shape{dim0, hidden_size}; std::vector norm_shape{hidden_size}; - auto input_half = makeContigTensor(2, dtype); - auto weight_half = makeContigTensor(1, dtype); - auto bias_half = makeContigTensor(1, dtype); - fusion.addInput(input_half); - fusion.addInput(weight_half); - fusion.addInput(bias_half); - auto input = castOp(DataType::Float, input_half); - auto weight = castOp(DataType::Float, weight_half); - auto bias = castOp(DataType::Float, bias_half); + + auto input = makeContigTensor(2, dtype); + auto weight = makeContigTensor(1, dtype); + auto bias = makeContigTensor(1, dtype); + fusion.addInput(input); + fusion.addInput(weight); + fusion.addInput(bias); + if (dtype == DataType::Half) { + input = castOp(DataType::Float, input); + weight = castOp(DataType::Float, weight); + bias = castOp(DataType::Float, bias); + } auto result = layer_norm(input, norm_shape, weight, bias, eps_ptr); - auto result_output = castOp(dtype, result.output); - fusion.addOutput(result_output); + if (dtype == DataType::Half) { + result.output = castOp(DataType::Half, result.output); + } + fusion.addOutput(result.output); fusion.addOutput(result.mean); fusion.addOutput(result.invstd); @@ -1534,18 +1539,9 @@ TEST_P(LayerNormSharedMemoryTest, FusionLayerNormSharedMemoryBuffer_CUDA) { auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); auto runtime = executor_cache.getMostRecentKernelRuntime(); if (has_enough_regs_smem) { - // For dtype float, no op scheduler is also used. - if (dtype == DataType::Float) { - EXPECT_THAT( - runtime->fusionSegments()->groups(), - UnorderedElementsAre( - HeuristicIs(SchedulerType::NoOp), - HeuristicIs(SchedulerType::InnerPersistent))); - } else { - EXPECT_THAT( - runtime->fusionSegments()->groups(), - UnorderedElementsAre(HeuristicIs(SchedulerType::InnerPersistent))); - } + EXPECT_THAT( + runtime->fusionSegments()->groups(), + ElementsAre(HeuristicIs(SchedulerType::InnerPersistent))); Fusion* scheduled_fusion = runtime->executors() .back() ->as()