diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py
index a811531d42..eea724dece 100644
--- a/mmcv/ops/box_iou_rotated.py
+++ b/mmcv/ops/box_iou_rotated.py
@@ -149,6 +149,10 @@ def box_iou_rotated(bboxes1: torch.Tensor,
         bboxes2 = bboxes2 * scale_mat
     bboxes1 = bboxes1.contiguous()
     bboxes2 = bboxes2.contiguous()
+
+    # cast "bboxes2" according to the type of "bboxes1"
+    bboxes2 = bboxes2.type_as(bboxes1)
+
     ext_module.box_iou_rotated(
         bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
     if not aligned:
diff --git a/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh b/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
index abd47cd854..c55dde36a6 100644
--- a/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
+++ b/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
@@ -28,8 +28,8 @@ __global__ void box_iou_rotated_cuda_kernel(
 
       int base1 = b1 * 5;
 
-      float block_boxes1[5];
-      float block_boxes2[5];
+      T block_boxes1[5];
+      T block_boxes2[5];
 
       block_boxes1[0] = dev_boxes1[base1 + 0];
       block_boxes1[1] = dev_boxes1[base1 + 1];
@@ -55,8 +55,8 @@ __global__ void box_iou_rotated_cuda_kernel(
 
      int base1 = b1 * 5;
 
-     float block_boxes1[5];
-     float block_boxes2[5];
+     T block_boxes1[5];
+     T block_boxes2[5];
 
      block_boxes1[0] = dev_boxes1[base1 + 0];
      block_boxes1[1] = dev_boxes1[base1 + 1];
diff --git a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
index 3c13e06237..6b3eb06aa8 100644
--- a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
@@ -16,10 +16,13 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
 
   at::cuda::CUDAGuard device_guard(boxes1.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  box_iou_rotated_cuda_kernel<scalar_t>
-      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-          num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
-          boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
-          mode_flag, aligned);
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
+        box_iou_rotated_cuda_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
+                boxes2.data_ptr<scalar_t>(),
+                (scalar_t*)ious.data_ptr<scalar_t>(), mode_flag, aligned);
+      });
   AT_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py
index 3af811d0fe..4467bbe5b9 100644
--- a/tests/test_ops/test_box_iou_rotated.py
+++ b/tests/test_ops/test_box_iou_rotated.py
@@ -3,6 +3,13 @@
 import pytest
 import torch
 
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # would be imported and used; we should test if our modules support it.
+    from torch.cuda.amp import autocast
+    _IS_AUTOCAST_AVAILABLE = True
+except ImportError:
+    _IS_AUTOCAST_AVAILABLE = False
 from mmcv.ops import box_iou_rotated
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
 
@@ -97,6 +104,33 @@ def test_box_iou_rotated(self, device):
         assert np.allclose(
             ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
 
+    @pytest.mark.skipif(
+        not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='requires CUDA support')
+    def test_box_iou_rotated_with_autocast(self, device):
+        from mmcv.ops import box_iou_rotated
+        np_boxes1 = np.asarray(
+            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+             [7.0, 7.0, 8.0, 8.0, 0.4]],
+            dtype=np.float32)
+        np_boxes2 = np.asarray(
+            [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+             [5.0, 5.0, 6.0, 7.0, 0.4]],
+            dtype=np.float32)
+        np_expect_ious = np.asarray(
+            [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
+             [0.0000, 0.0000, 0.3622]],
+            dtype=np.float16)
+
+        boxes1 = torch.from_numpy(np_boxes1).to(device).type(torch.half)
+        boxes2 = torch.from_numpy(np_boxes2).to(device)
+
+        with autocast(enabled=True):
+            # test cw angle definition
+            ious = box_iou_rotated(boxes1, boxes2)
+            assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-2)
+
     def test_box_iou_rotated_iof_cpu(self):
         np_boxes1 = np.asarray(
             [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],