diff --git a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp
index 10c2794e5b..45e668440f 100644
--- a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp
+++ b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp
@@ -22,6 +22,9 @@ using namespace at;
 #define CHECK_MLU_INPUT(x) \
   CHECK_MLU(x);            \
   CHECK_CONTIGUOUS(x)
+#define CHECK_MUSA_INPUT(x) \
+  CHECK_MUSA(x);            \
+  CHECK_CONTIGUOUS(x)
 #define CHECK_CPU_INPUT(x) \
   CHECK_CPU(x);            \
   CHECK_CONTIGUOUS(x)
diff --git a/mmcv/ops/csrc/pytorch/ball_query.cpp b/mmcv/ops/csrc/pytorch/ball_query.cpp
index 7b56568338..cff552b8ea 100644
--- a/mmcv/ops/csrc/pytorch/ball_query.cpp
+++ b/mmcv/ops/csrc/pytorch/ball_query.cpp
@@ -16,7 +16,6 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius,
 void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
                         Tensor idx_tensor, int b, int n, int m,
                         float min_radius, float max_radius, int nsample) {
-  std::cout<<"ball_query_forward"<<std::endl;
   ball_query_forward_impl(b, n, m, min_radius, max_radius, new_xyz_tensor,
                           xyz_tensor, idx_tensor, nsample);
 }
diff --git a/mmcv/ops/csrc/pytorch/spconv_ops.cpp b/mmcv/ops/csrc/pytorch/spconv_ops.cpp
--- a/mmcv/ops/csrc/pytorch/spconv_ops.cpp
+++ b/mmcv/ops/csrc/pytorch/spconv_ops.cpp
@@ -65,6 +65,8 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
       padding, dilation, outPadding, _subM, _transpose);
 };
 
+
+
 template <unsigned NDim>
 std::vector<torch::Tensor> GetIndicePairsForwardMUSAKernelLauncher(
     torch::Tensor indices, int64_t batchSize,
@@ -97,6 +99,28 @@ std::vector<torch::Tensor> get_indice_pairs_backward_cuda(
       stride, padding, dilation, outPadding, _subM, _transpose);
 };
 
+#ifdef MMCV_WITH_MUSA
+template <unsigned NDim>
+std::vector<torch::Tensor> GetIndicePairsBackwardMUSAKernelLauncher(
+    torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
+
+template <unsigned NDim>
+std::vector<torch::Tensor> get_indice_pairs_backward_musa(
+    torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
+  return GetIndicePairsBackwardMUSAKernelLauncher<NDim>(
+      indices, gridOut, batchSize, outSpatialShape, spatialShape, kernelSize,
+      stride, padding, dilation, outPadding, _subM, _transpose);
+};
+#endif
+
 template <unsigned NDim>
 std::vector<torch::Tensor> get_indice_pairs_forward(
     torch::Tensor indices, int64_t batchSize,
@@ -150,6 +174,15 @@ std::vector<torch::Tensor> get_indice_pairs_backward(
     AT_ERROR("get_indice_pairs is not compiled with GPU support");
 #endif
   } else {
+#ifdef MMCV_WITH_MUSA
+    if (indices.device().type() == at::kMUSA) {
+      CHECK_MUSA_INPUT(indices);
+      CHECK_MUSA_INPUT(gridOut);
+      return get_indice_pairs_backward_musa<NDim>(
+          indices, gridOut, batchSize, outSpatialShape, spatialShape, kernelSize,
+          stride, padding, dilation, outPadding, _subM, _transpose);
+    }
+#endif
     AT_ERROR("get_indice_pairs is not implemented on CPU");
   }
 }
diff --git a/tests/test_ops/test_chamfer_distance.py b/tests/test_ops/test_chamfer_distance.py
index 99f4627166..418294fcb5 100644
--- a/tests/test_ops/test_chamfer_distance.py
+++ b/tests/test_ops/test_chamfer_distance.py
@@ -57,7 +57,14 @@ def torch_to_np_type(dtype):
             marks=pytest.mark.skipif(
                 not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
 ])
-@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
+@pytest.mark.parametrize('dtype', [
+    pytest.param(
+        torch.half,
+        marks=pytest.mark.skipif(
+            IS_MUSA_AVAILABLE,
+            reason='TODO haowen.han@mthreads.com: not supported yet')),
+    torch.float32
+])
 @pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)])
 def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
     if device == 'musa':
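For orientation (this note is not part of the patch): the last hunk keeps `torch.half` out of the MUSA dtype matrix, so `float32` is the path a MUSA build is expected to pass. Below is a minimal smoke-test sketch that exercises that path outside pytest. The `torch_musa` import and the `'musa'` device string are assumptions about the out-of-tree Moore Threads plugin, and the fallback goes to CUDA because mmcv ships no CPU kernel for `chamfer_distance`:

```python
# Sketch: smoke-test the float32 chamfer_distance path the test above covers.
# Assumes the out-of-tree torch_musa package (which registers the 'musa'
# device with PyTorch) and an mmcv build compiled with MUSA kernels.
import torch

try:
    import torch_musa  # noqa: F401  -- assumption: MUSA plugin installed
    device = 'musa'
except ImportError:
    # chamfer_distance has no CPU kernel in mmcv, so fall back to CUDA.
    device = 'cuda' if torch.cuda.is_available() else None

if device is None:
    raise SystemExit('needs a MUSA or CUDA device')

from mmcv.ops import chamfer_distance

# Two point sets shaped (batch, num_points, 2), as in the test above.
xyz1 = torch.rand(2, 600, 2, dtype=torch.float32, device=device)
xyz2 = torch.rand(2, 600, 2, dtype=torch.float32, device=device)

# dist1[b, i] is the squared distance from xyz1[b, i] to its nearest point
# in xyz2[b]; idx1 holds that point's index (dist2/idx2 are the reverse).
dist1, dist2, idx1, idx2 = chamfer_distance(xyz1, xyz2)
print(dist1.shape, idx1.shape)  # expected: (2, 600) and (2, 600)
```

The spconv hunks follow the same dispatch convention as this Python-visible path: `get_indice_pairs_backward` routes `at::kMUSA` tensors through `CHECK_MUSA_INPUT` to `get_indice_pairs_backward_musa<NDim>`, mirroring the existing CUDA branch.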