From 65a8027519b6a04c42d915b125122373668e5fcf Mon Sep 17 00:00:00 2001
From: nijkah
Date: Mon, 14 Aug 2023 07:46:32 +0000
Subject: [PATCH 1/3] [Feature] Support AMP in box_iou_rotated

---
 mmcv/ops/box_iou_rotated.py                   |  4 +++
 .../csrc/common/cuda/box_iou_rotated_cuda.cuh |  8 ++---
 .../csrc/pytorch/cuda/box_iou_rotated_cuda.cu | 13 ++++---
 tests/test_ops/test_box_iou_rotated.py        | 36 +++++++++++++++++++
 4 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py
index 2443af27c9..1bbbad8d58 100644
--- a/mmcv/ops/box_iou_rotated.py
+++ b/mmcv/ops/box_iou_rotated.py
@@ -141,6 +141,10 @@ def box_iou_rotated(bboxes1: torch.Tensor,
         bboxes2 = bboxes2 * flip_mat
     bboxes1 = bboxes1.contiguous()
     bboxes2 = bboxes2.contiguous()
+
+    # cast bboxes2 to the dtype of bboxes1 to support mixed-precision inputs
+    bboxes2 = bboxes2.type_as(bboxes1)
+
     ext_module.box_iou_rotated(
         bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
     if not aligned:
diff --git a/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh b/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
index abd47cd854..c55dde36a6 100644
--- a/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
+++ b/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
@@ -28,8 +28,8 @@ __global__ void box_iou_rotated_cuda_kernel(
 
       int base1 = b1 * 5;
 
-      float block_boxes1[5];
-      float block_boxes2[5];
+      T block_boxes1[5];
+      T block_boxes2[5];
 
       block_boxes1[0] = dev_boxes1[base1 + 0];
       block_boxes1[1] = dev_boxes1[base1 + 1];
@@ -55,8 +55,8 @@ __global__ void box_iou_rotated_cuda_kernel(
 
      int base1 = b1 * 5;
 
-      float block_boxes1[5];
-      float block_boxes2[5];
+      T block_boxes1[5];
+      T block_boxes2[5];
 
       block_boxes1[0] = dev_boxes1[base1 + 0];
       block_boxes1[1] = dev_boxes1[base1 + 1];
diff --git a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
index 3c13e06237..72c9370f25 100644
--- a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
@@ -16,10 +16,13 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
 
   at::cuda::CUDAGuard device_guard(boxes1.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  box_iou_rotated_cuda_kernel<scalar_t>
-      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-          num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
-          boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
-          mode_flag, aligned);
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
+        box_iou_rotated_cuda_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
+                boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
+                mode_flag, aligned);
+    });
   AT_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py
index 9f5e0dfa3e..09c9229cad 100644
--- a/tests/test_ops/test_box_iou_rotated.py
+++ b/tests/test_ops/test_box_iou_rotated.py
@@ -3,6 +3,16 @@
 import pytest
 import torch
 
+from mmcv.ops import box_iou_rotated
+
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # would be imported and used; we should test if our modules support it.
+    from torch.cuda.amp import autocast
+    _IS_AUTOCAST_AVAILABLE = True
+except ImportError:
+    _IS_AUTOCAST_AVAILABLE = False
+
 
 class TestBoxIoURotated:
 
@@ -84,6 +94,32 @@ def test_box_iou_rotated_cuda(self):
         assert np.allclose(
             ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
 
+    @pytest.mark.skipif(
+        not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='requires CUDA support')
+    def test_box_iou_rotated_with_autocast(self):
+        np_boxes1 = np.asarray(
+            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+             [7.0, 7.0, 8.0, 8.0, 0.4]],
+            dtype=np.float32)
+        np_boxes2 = np.asarray(
+            [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+             [5.0, 5.0, 6.0, 7.0, 0.4]],
+            dtype=np.float32)
+        np_expect_ious = np.asarray(
+            [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
+             [0.0000, 0.0000, 0.3622]],
+            dtype=np.float16)
+
+        boxes1 = torch.from_numpy(np_boxes1).to('cuda').type(torch.half)
+        boxes2 = torch.from_numpy(np_boxes2).to('cuda')
+
+        with autocast(enabled=True):
+            # test the default cw angle definition under fp16 autocast
+            ious = box_iou_rotated(boxes1, boxes2)
+        assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-2)
+
     def test_box_iou_rotated_iof_cpu(self):
         from mmcv.ops import box_iou_rotated
         np_boxes1 = np.asarray(

From 40a60c6ebe1922f8dbd1421877cfd21e6e1f7ed6 Mon Sep 17 00:00:00 2001
From: nijkah
Date: Wed, 16 Aug 2023 06:59:02 +0000
Subject: [PATCH 2/3] fix import

---
 tests/test_ops/test_box_iou_rotated.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py
index 09c9229cad..30354f36fb 100644
--- a/tests/test_ops/test_box_iou_rotated.py
+++ b/tests/test_ops/test_box_iou_rotated.py
@@ -3,8 +3,6 @@
 import pytest
 import torch
 
-from mmcv.ops import box_iou_rotated
-
 try:
     # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
     # would be imported and used; we should test if our modules support it.
@@ -99,6 +97,7 @@ def test_box_iou_rotated_cuda(self):
     @pytest.mark.skipif(
         not torch.cuda.is_available(), reason='requires CUDA support')
     def test_box_iou_rotated_with_autocast(self):
+        from mmcv.ops import box_iou_rotated
         np_boxes1 = np.asarray(
             [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
              [7.0, 7.0, 8.0, 8.0, 0.4]],

From 15feb5978b942ed42b2b3d312b9cdd1ccfd5d8a2 Mon Sep 17 00:00:00 2001
From: nijkah
Date: Wed, 1 Nov 2023 15:31:38 +0900
Subject: [PATCH 3/3] fix lint

---
 mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
index 72c9370f25..6b3eb06aa8 100644
--- a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
@@ -17,12 +17,12 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
   at::cuda::CUDAGuard device_guard(boxes1.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
+      boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
         box_iou_rotated_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                 num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
-                boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
-                mode_flag, aligned);
-    });
+                boxes2.data_ptr<scalar_t>(),
+                (scalar_t*)ious.data_ptr<scalar_t>(), mode_flag, aligned);
+      });
   AT_CUDA_CHECK(cudaGetLastError());
 }
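
---
Usage note (an illustrative sketch, not part of the patch series): with these
patches applied, box_iou_rotated should accept fp16 inputs produced under
torch.cuda.amp.autocast, since the CUDA op now dispatches on the input dtype
and the Python wrapper casts bboxes2 to match bboxes1. The sketch assumes a
CUDA device and an mmcv build that includes this series; the box values are
reused from the unit test above.

    import torch
    from mmcv.ops import box_iou_rotated

    # Rotated boxes in (x_center, y_center, width, height, angle) format,
    # created as fp32 CUDA tensors.
    boxes1 = torch.tensor([[1.0, 1.0, 3.0, 4.0, 0.5],
                           [2.0, 2.0, 3.0, 4.0, 0.6],
                           [7.0, 7.0, 8.0, 8.0, 0.4]], device='cuda')
    boxes2 = torch.tensor([[0.0, 2.0, 2.0, 5.0, 0.3],
                           [2.0, 1.0, 3.0, 3.0, 0.5],
                           [5.0, 5.0, 6.0, 7.0, 0.4]], device='cuda')

    # Mixed fp16/fp32 inputs: AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates
    # the half-precision kernel, and the type_as cast in box_iou_rotated
    # brings boxes2 to fp16 before the extension op is called.
    with torch.cuda.amp.autocast(enabled=True):
        ious = box_iou_rotated(boxes1.half(), boxes2)

    print(ious.dtype, ious.shape)  # torch.float16 torch.Size([3, 3])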
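
Continuing the sketch above (also illustrative): the aligned mode exercised
elsewhere in the test file runs through the same dispatched kernel, comparing
boxes pairwise and returning a vector of IoUs instead of a full matrix.

    # Pairwise IoU of box i in boxes1 with box i in boxes2, still under fp16.
    with torch.cuda.amp.autocast(enabled=True):
        ious_aligned = box_iou_rotated(boxes1.half(), boxes2, aligned=True)

    print(ious_aligned.dtype, ious_aligned.shape)  # torch.float16 torch.Size([3])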