Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support AMP in box_iou_rotated #2899

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def box_iou_rotated(bboxes1: torch.Tensor,
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()

# cast "bboxes2" to the type of "bboxes1" so that both inputs handed to
# the extension op share one dtype (e.g. fp16 boxes mixed with fp32 boxes)
bboxes2 = bboxes2.type_as(bboxes1)

ext_module.box_iou_rotated(
bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
if not aligned:
Expand Down
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ __global__ void box_iou_rotated_cuda_kernel(

int base1 = b1 * 5;

float block_boxes1[5];
float block_boxes2[5];
T block_boxes1[5];
T block_boxes2[5];

block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
Expand All @@ -55,8 +55,8 @@ __global__ void box_iou_rotated_cuda_kernel(

int base1 = b1 * 5;

float block_boxes1[5];
float block_boxes2[5];
T block_boxes1[5];
T block_boxes2[5];

block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
Expand Down
13 changes: 8 additions & 5 deletions mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,

at::cuda::CUDAGuard device_guard(boxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
box_iou_rotated_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
box_iou_rotated_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(),
(scalar_t*)ious.data_ptr<scalar_t>(), mode_flag, aligned);
});
AT_CUDA_CHECK(cudaGetLastError());
}
34 changes: 34 additions & 0 deletions tests/test_ops/test_box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
import pytest
import torch

try:
# torch.cuda.amp.autocast is only importable with PyTorch >= 1.6.0.
# Record whether it is available so the autocast tests can be skipped
# on older versions instead of failing at import time.
from torch.cuda.amp import autocast
_IS_AUTOCAST_AVAILABLE = True
except ImportError:
_IS_AUTOCAST_AVAILABLE = False
from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

Expand Down Expand Up @@ -97,6 +104,33 @@ def test_box_iou_rotated(self, device):
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

@pytest.mark.skipif(
    not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('device', ['cuda'])
def test_box_iou_rotated_with_autocast(self, device):
    """Check ``box_iou_rotated`` on mixed fp16/fp32 inputs under autocast.

    ``boxes1`` is converted to half precision while ``boxes2`` is left as
    float32, and the op is invoked inside ``torch.cuda.amp.autocast``.
    The resulting IoUs are compared with reference values using a loose
    tolerance (``atol=1e-2``).

    Args:
        device (str): Device the input tensors are moved to. It is
            parametrized explicitly so the argument is always resolvable
            by pytest.
    """
    # Boxes are given as 5 values per row; the last value is the angle.
    np_boxes1 = np.asarray(
        [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
         [7.0, 7.0, 8.0, 8.0, 0.4]],
        dtype=np.float32)
    np_boxes2 = np.asarray(
        [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
         [5.0, 5.0, 6.0, 7.0, 0.4]],
        dtype=np.float32)
    # Reference IoUs stored in half precision, matching the fp16 input.
    np_expect_ious = np.asarray(
        [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
         [0.0000, 0.0000, 0.3622]],
        dtype=np.float16)

    # Deliberately mix dtypes: boxes1 is fp16, boxes2 stays fp32.
    boxes1 = torch.from_numpy(np_boxes1).to(device).type(torch.half)
    boxes2 = torch.from_numpy(np_boxes2).to(device)

    with autocast(enabled=True):
        # test cw angle definition
        ious = box_iou_rotated(boxes1, boxes2)
    assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-2)

def test_box_iou_rotated_iof_cpu(self):
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
Expand Down
Loading