Skip to content

Commit

Permalink
Fix some bugs in get_indice_pairs_backward_musa
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhaowen-mt committed Jan 12, 2024
1 parent b2953be commit fb7b89f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
3 changes: 3 additions & 0 deletions mmcv/ops/csrc/common/pytorch_cpp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using namespace at;
// Validate an input tensor for MLU kernels: asserts the tensor lives on an
// MLU device (CHECK_MLU, defined earlier in this header — presumably a
// device-type check; confirm) and is contiguous in memory.
#define CHECK_MLU_INPUT(x) \
CHECK_MLU(x); \
CHECK_CONTIGUOUS(x)
// Validate an input tensor for MUSA kernels: asserts MUSA device placement
// (via CHECK_MUSA, expected to be defined alongside the other device checks
// in this header) plus contiguity. Mirrors CHECK_MLU_INPUT / CHECK_CPU_INPUT.
#define CHECK_MUSA_INPUT(x) \
CHECK_MUSA(x); \
CHECK_CONTIGUOUS(x)
// Validate an input tensor for CPU kernels: asserts CPU device placement
// and contiguity.
#define CHECK_CPU_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
Expand Down
33 changes: 33 additions & 0 deletions mmcv/ops/csrc/pytorch/spconv_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
padding, dilation, outPadding, _subM, _transpose);
};



template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMUSAKernelLauncher(
torch::Tensor indices, int64_t batchSize,
Expand Down Expand Up @@ -97,6 +99,28 @@ std::vector<torch::Tensor> get_indice_pairs_backward_cuda(
stride, padding, dilation, outPadding, _subM, _transpose);
};

#ifdef MMCV_WITH_MUSA
// Forward declaration of the MUSA kernel launcher for the backward pass of
// sparse-conv indice-pair generation. The definition lives in a separate
// MUSA translation unit; the signature here must stay in sync with it.
template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsBackwardMUSAKernelLauncher(
torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

// MUSA dispatch entry for get_indice_pairs_backward: a thin wrapper that
// forwards all arguments, unchanged and in order, to the kernel launcher
// above. Mirrors the CUDA counterpart get_indice_pairs_backward_cuda so the
// device dispatch in get_indice_pairs_backward can treat both uniformly.
// _subM / _transpose are int64_t flags (presumably 0/1 booleans chosen for
// TorchScript-friendly signatures — confirm against the launcher).
template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_backward_musa(
torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
return GetIndicePairsBackwardMUSAKernelLauncher<NDim>(
indices, gridOut, batchSize, outSpatialShape, spatialShape, kernelSize,
stride, padding, dilation, outPadding, _subM, _transpose);
};  // trailing ';' kept for consistency with the CUDA wrappers in this file
#endif

template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_forward(
torch::Tensor indices, int64_t batchSize,
Expand Down Expand Up @@ -150,6 +174,15 @@ std::vector<torch::Tensor> get_indice_pairs_backward(
AT_ERROR("get_indice_pairs is not compiled with GPU support");
#endif
} else {
#ifdef MMCV_WITH_MUSA
if (indices.device().type() == at::kMUSA) {
CHECK_MUSA_INPUT(indices);
CHECK_MUSA_INPUT(gridOut);
return get_indice_pairs_backward_musa<NDim>(
indices, gridOut, batchSize, outSpatialShape, spatialShape, kernelSize,
stride, padding, dilation, outPadding, _subM, _transpose);
}
#endif
AT_ERROR("get_indice_pairs is not implemented on CPU");
}
}
Expand Down
9 changes: 8 additions & 1 deletion tests/test_ops/test_chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ def torch_to_np_type(dtype):
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
@pytest.mark.parametrize('dtype', [
pytest.param(
torch.half,
marks=pytest.mark.skipif(
IS_MUSA_AVAILABLE,
reason='TODO [email protected]: not supported yet')),
torch.float32
])
@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)])
def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
if device == 'musa':
Expand Down

0 comments on commit fb7b89f

Please sign in to comment.