Skip to content

Commit

Permalink
Fix bugs in get_indice_pairs_backward_musa
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhaowen-mt committed Jan 15, 2024
1 parent b2953be commit b5fda52
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 9 deletions.
3 changes: 3 additions & 0 deletions mmcv/ops/csrc/common/pytorch_cpp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using namespace at;
#define CHECK_MLU_INPUT(x) \
CHECK_MLU(x); \
CHECK_CONTIGUOUS(x)
// Validates an input tensor for the MUSA backend: applies CHECK_MUSA (device
// placement check) followed by CHECK_CONTIGUOUS (memory-layout check),
// mirroring the CHECK_MLU_INPUT / CHECK_CPU_INPUT macros in this header.
// NOTE: keep this a backslash-continued macro; do not place comments on the
// continuation lines.
#define CHECK_MUSA_INPUT(x) \
CHECK_MUSA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_CPU_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
Expand Down
1 change: 0 additions & 1 deletion mmcv/ops/csrc/pytorch/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius,
// Entry point for the ball-query op: forwards all arguments unchanged to the
// device-dispatched implementation (ball_query_forward_impl).
//
// Args (semantics inferred from the impl signature — confirm against the
// device kernels):
//   new_xyz_tensor: query points; xyz_tensor: source points;
//   idx_tensor: output indices written by the implementation.
//   b/n/m: presumably batch size and point counts — TODO confirm in impl.
//   min_radius/max_radius: search shell bounds; nsample: max neighbors.
void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
                        Tensor idx_tensor, int b, int n, int m,
                        float min_radius, float max_radius, int nsample) {
  // Removed leftover debug print (std::cout << "ball_query_forward"):
  // library code must not write to stdout on every call.
  ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample,
                          new_xyz_tensor, xyz_tensor, idx_tensor);
}
Expand Down
7 changes: 0 additions & 7 deletions mmcv/ops/csrc/pytorch/nms_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const Tensor labels,
const float iou_threshold, const int multi_label) {

std::cout<<"nms_rotated"<<std::endl;
std::cout<<dets<<std::endl;
std::cout<<dets.device()<<std::endl;
std::cout<<dets.is_cuda()<<std::endl;
std::cout<<dets.is_privateuseone()<<std::endl;
assert(dets.is_cuda() == scores.is_cuda());
if (dets.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
return nms_rotated_cuda(dets, scores, order, dets_sorted.contiguous(),
iou_threshold, multi_label);
#else
std::cout<<"nms_rotated in cuda"<<std::endl;
AT_ERROR("Not compiled with GPU support");
#endif
#ifdef MMCV_WITH_XLA
Expand All @@ -64,7 +58,6 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
#endif
#ifdef MMCV_WITH_MUSA
} else if (dets.device().type() == ::at::kPrivateUse1) {
std::cout<<"privateuse1"<<std::endl;
return nms_rotated_musa(dets, scores, order, dets_sorted.contiguous(),
iou_threshold, multi_label);
#endif
Expand Down
33 changes: 33 additions & 0 deletions mmcv/ops/csrc/pytorch/spconv_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
padding, dilation, outPadding, _subM, _transpose);
};



template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMUSAKernelLauncher(
torch::Tensor indices, int64_t batchSize,
Expand Down Expand Up @@ -97,6 +99,28 @@ std::vector<torch::Tensor> get_indice_pairs_backward_cuda(
stride, padding, dilation, outPadding, _subM, _transpose);
};

#ifdef MMCV_WITH_MUSA
// Forward declaration of the MUSA kernel launcher for the backward pass of
// get_indice_pairs; the definition lives in the MUSA backend sources.
template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsBackwardMUSAKernelLauncher(
torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

// Thin dispatch wrapper for the MUSA backend: forwards every argument
// unchanged to the kernel launcher above, mirroring the CUDA counterpart
// get_indice_pairs_backward_cuda defined earlier in this file.
template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_backward_musa(
torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
return GetIndicePairsBackwardMUSAKernelLauncher<NDim>(
indices, gridOut, batchSize, outSpatialShape, spatialShape, kernelSize,
stride, padding, dilation, outPadding, _subM, _transpose);
};
#endif

template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_forward(
torch::Tensor indices, int64_t batchSize,
Expand Down Expand Up @@ -150,6 +174,15 @@ std::vector<torch::Tensor> get_indice_pairs_backward(
AT_ERROR("get_indice_pairs is not compiled with GPU support");
#endif
} else {
#ifdef MMCV_WITH_MUSA
if (indices.device().type() == at::kMUSA) {
CHECK_MUSA_INPUT(indices);
CHECK_MUSA_INPUT(gridOut);
return get_indice_pairs_backward_musa<NDim>(
indices, gridOut, batchSize, outSpatialShape, spatialShape, kernelSize,
stride, padding, dilation, outPadding, _subM, _transpose);
}
#endif
AT_ERROR("get_indice_pairs is not implemented on CPU");
}
}
Expand Down
9 changes: 8 additions & 1 deletion tests/test_ops/test_chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ def torch_to_np_type(dtype):
marks=pytest.mark.skipif(
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float32])
@pytest.mark.parametrize('dtype', [
pytest.param(
torch.half,
marks=pytest.mark.skipif(
IS_MUSA_AVAILABLE,
reason='TODO [email protected]: not supported yet')),
torch.float32
])
@pytest.mark.parametrize('shape', [(2, 600, 2), (2, 600, 2)])
def test_chamfer_distance_npu_dynamic_shape(dtype, device, shape):
if device == 'musa':
Expand Down

0 comments on commit b5fda52

Please sign in to comment.