Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not fuse resize-based ops and index ops (yet) #3845

Merged
merged 3 commits into from
Feb 7, 2025

Conversation

naoyam
Copy link
Collaborator

@naoyam naoyam commented Feb 6, 2025

Fixes #3718

@naoyam
Copy link
Collaborator Author

naoyam commented Feb 6, 2025

!test

Copy link

github-actions bot commented Feb 6, 2025

Review updated until commit abcafb5

Description

  • Prevents fusing resize-based ops with index ops in scheduler

  • Adds utility functions to detect resize and index ops

  • Includes a test case for the new behavior


Changes walkthrough 📝

Relevant files
Enhancement
registry.cpp
Added check for resize and index ops                                         

csrc/scheduler/registry.cpp

  • Added check to prevent scheduling if fusion has both resize-based and
    index ops
  • +6/-0     
    registry_utils.cpp
    Implemented resize and index op detection                               

    csrc/scheduler/registry_utils.cpp

  • Implemented hasResizeAndIndexOps to detect resize and index ops in
    fusion
  • +21/-0   
    resize.cpp
    Updated resize scheduler check                                                     

    csrc/scheduler/resize.cpp

  • Updated canScheduleCompileTime to use
    scheduler_tools::hasResizeBasedOps
  • +1/-1     
    resize_utils.cpp
    Added resize op utility functions                                               

    csrc/scheduler/tools/resize_utils.cpp

    • Added isResizeBasedOp and hasResizeBasedOps functions
    +8/-0     
    registry_utils.h
    Declared resize and index op detection function                   

    csrc/scheduler/registry_utils.h

    • Declared hasResizeAndIndexOps in SchedulerTopologyChecker
    +2/-0     
    resize_utils.h
    Declared resize op utility functions                                         

    csrc/scheduler/tools/resize_utils.h

    • Declared isResizeBasedOp and hasResizeBasedOps functions
    +4/-0     
    Tests
    test_resize.cpp
    Added test for resize and index op separation                       

    tests/cpp/test_resize.cpp

    • Added test case to verify separation of resize and index ops
    +51/-0   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Impact

    The new check for resize and index ops may lead to additional overhead in the scheduling process. It is important to verify that this does not negatively impact performance.

    if (registry_utils::SchedulerTopologyChecker::hasResizeAndIndexOps(fusion)) {
      scheduler_debug_utils::canScheduleRejectReason(
          scheduler_type, "has resize-based ops and index ops");
      return false;
    }
    Code Duplication

    The function isResizeBasedOp is defined in both resize_utils.cpp and registry_utils.cpp. Consider consolidating this function to avoid duplication.

    bool isResizeBasedOp(Expr* expr) {
      return expr->isOneOf<SliceOp, PadOp>();
    }
    Test Coverage

    Ensure that the new test covers all edge cases and scenarios where resize and index ops are mixed. Additional tests may be necessary to validate the behavior.

    // Mixing resize and index ops is not supported yet.Specifically,
    // resize requires TensorIndexer, which is based on IdModel, but index
    // ops like take_along_axis is not yet supported by IdModel.
    TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());
    
      auto tv0 = makeContigConcreteTensor({128, 4095});
      fusion.addInput(tv0);
      auto tv1 = makeContigConcreteTensor({1, 4096}, DataType::Int);
      fusion.addInput(tv1);
      auto tv2 = slice(
          tv1,
          {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(1L)},
           {IrBuilder::create<Val>(1L), IrBuilder::create<Val>(4096)}});
      auto tv3 = takeAlongAxis(tv0, tv2, 0);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
      auto t0 = at::randn({128, 4095}, options);
      auto t1 = at::randint(0, 128, {1, 4096}, options_int);
      std::vector<c10::IValue> inputs({t0, t1});
    
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs(inputs);
      testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
    
      FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
    
      EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2)
          << "Unexpected segmentation";
    
      // Make sure two ops are separated into their own segments
      for (auto segmented_group : runtime->fusionSegments()->groups()) {
        bool has_resize = false;
        bool has_index_op = false;
        for (auto expr : segmented_group->exprs()) {
          if (scheduler_tools::isResizeBasedOp(expr)) {
            has_resize = true;
          } else if (
              expr->isOneOf<TorchGatherOp, ScatterOp, IndexSelectOp, SelectOp>()) {
            has_index_op = true;
          }
        }
    
        EXPECT_NE(has_resize, has_index_op);
      }
    }

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Feb 6, 2025

    !test

    @naoyam naoyam requested a review from jjsjann123 February 6, 2025 23:22
    Copy link
    Collaborator

    @jjsjann123 jjsjann123 left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM. cc'ing @protonu as well, since we'll be touching/changing this likely.

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Feb 7, 2025

    !build

    @naoyam naoyam merged commit b6e1530 into main Feb 7, 2025
    16 of 17 checks passed
    @naoyam naoyam deleted the dont_fuse_resize_and_index_ops branch February 7, 2025 20:51
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    BFS traversal could not visit some nodes while fusing take_along_axis
    2 participants