From 99fa5c6da0ddebb1e4b964ee3223720f73f9fa7c Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 4 Sep 2023 20:20:57 +0800 Subject: [PATCH 1/2] Support pre released pytorch2.1.0 (#2865) --- .../common/utils/spconv/tensorview/tensorview.h | 5 +++-- setup.py | 15 +++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h b/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h index 27745beaa5..66e01a8ed1 100644 --- a/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h +++ b/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h @@ -319,8 +319,9 @@ struct ShapeBase : public SimpleVector { TV_HOST_DEVICE_INLINE ShapeBase(std::initializer_list shape) : SimpleVector(shape) {} - template class Container> - ShapeBase(Container shape) : SimpleVector(shape) {} + // TODO: find out why this template can no be used on windows + // template class Container> + // ShapeBase(Container shape) : SimpleVector(shape) {} TV_HOST_DEVICE_INLINE ShapeBase(const ShapeBase &shape) : SimpleVector(shape) {} ShapeBase(const std::vector &arr) : SimpleVector(arr) {} diff --git a/setup.py b/setup.py index 9393d04264..12f3f46b86 100644 --- a/setup.py +++ b/setup.py @@ -267,7 +267,15 @@ def get_extensions(): # to compile those cpp files, so there is no need to add the # argument if platform.system() != 'Windows': - extra_compile_args['cxx'] = ['-std=c++14'] + if parse_version(torch.__version__) <= parse_version('1.12.1'): + extra_compile_args['cxx'] = ['-std=c++14'] + else: + extra_compile_args['cxx'] = ['-std=c++17'] + else: + if parse_version(torch.__version__) <= parse_version('1.12.1'): + extra_compile_args['cxx'] = ['/std:c++14'] + else: + extra_compile_args['cxx'] = ['/std:c++17'] include_dirs = [] @@ -456,7 +464,10 @@ def get_mluops_version(file_path): # to compile those cpp files, so there is no need to add the # argument if 'nvcc' in extra_compile_args and platform.system() != 'Windows': - extra_compile_args['nvcc'] += ['-std=c++14'] + if parse_version(torch.__version__) <= parse_version('1.12.1'): + extra_compile_args['nvcc'] += ['-std=c++14'] + else: + extra_compile_args['nvcc'] += ['-std=c++17'] ext_ops = extension( name=ext_name, From af1c68c70e400cc5dbf37f0a4c8f7630a08a6edf Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 11 Oct 2023 13:08:35 +0800 Subject: [PATCH 2/2] temporarily disable mps ops for torch2.1.0 (#2958) --- setup.py | 8 +++++--- tests/test_ops/test_bbox.py | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 12f3f46b86..a0b6a429cc 100644 --- a/setup.py +++ b/setup.py @@ -424,9 +424,11 @@ def get_mluops_version(file_path): extra_compile_args['cxx'] += ['-ObjC++'] # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ - glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ - glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \ - glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm') + glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + # TODO: support mps ops on torch>=2.1.0 + if parse_version(torch.__version__) < parse_version('2.1.0'): + op_files += glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \ + glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm') extension = CppExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps')) diff --git a/tests/test_ops/test_bbox.py b/tests/test_ops/test_bbox.py index 3d1486eb01..752d877663 100644 --- a/tests/test_ops/test_bbox.py +++ b/tests/test_ops/test_bbox.py @@ -2,6 +2,7 @@ import numpy as np import pytest import torch +from mmengine.utils import digit_version from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, IS_NPU_AVAILABLE) @@ -56,7 +57,9 @@ def _test_bbox_overlaps(self, device='cpu', dtype=torch.float): pytest.param( 'mps', marks=pytest.mark.skipif( - not IS_MPS_AVAILABLE, reason='requires MPS support')), + not IS_MPS_AVAILABLE + or digit_version(torch.__version__) >= digit_version('2.1.0'), + reason='requires MPS support')), pytest.param( 'npu', marks=pytest.mark.skipif(