-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b5fda52
commit 5c4912b
Showing
8 changed files
with
33 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
import numpy as np | ||
import pytest | ||
import torch | ||
from mmengine.device import is_musa_available | ||
|
||
from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d | ||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE | ||
|
@@ -11,6 +10,7 @@ | |
torch.backends.mlu.matmul.allow_tf32 = False | ||
|
||
|
||
# TODO [email protected] there are some bugs for musa! | ||
@pytest.mark.parametrize('device', [ | ||
pytest.param( | ||
'cuda', | ||
|
@@ -40,6 +40,7 @@ def test_diff_iou_rotated_2d(device): | |
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) | ||
|
||
|
||
# TODO [email protected] there are some bugs for musa! | ||
@pytest.mark.parametrize('device', [ | ||
pytest.param( | ||
'cuda', | ||
|
@@ -68,48 +69,3 @@ def test_diff_iou_rotated_3d(device): | |
np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]]) | ||
ious = diff_iou_rotated_3d(boxes1, boxes2) | ||
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) | ||
|
||
|
||
@pytest.mark.skipif( | ||
is_musa_available(), | ||
reason='TODO [email protected] there are some bugs!') | ||
def test_diff_iou_rotated_2d_musa(): | ||
np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0], | ||
[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0], | ||
[0.5, 0.5, 1., 1., .0]]], | ||
dtype=np.float32) | ||
np_boxes2 = np.asarray( | ||
[[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., np.pi / 2], | ||
[0.5, 0.5, 1., 1., np.pi / 4], [1., 1., 1., 1., .0], | ||
[1.5, 1.5, 1., 1., .0]]], | ||
dtype=np.float32) | ||
|
||
boxes1 = torch.from_numpy(np_boxes1).musa() | ||
boxes2 = torch.from_numpy(np_boxes2).musa() | ||
|
||
np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]]) | ||
ious = diff_iou_rotated_2d(boxes1, boxes2) | ||
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-3) | ||
|
||
|
||
@pytest.mark.skipif( | ||
is_musa_available(), | ||
reason='TODO [email protected] there are some bugs!') | ||
def test_diff_iou_rotated_3d_musa(): | ||
np_boxes1 = np.asarray( | ||
[[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0], | ||
[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0], | ||
[.5, .5, .5, 1., 1., 1., .0]]], | ||
dtype=np.float32) | ||
np_boxes2 = np.asarray( | ||
[[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 2., np.pi / 2], | ||
[.5, .5, .5, 1., 1., 1., np.pi / 4], [1., 1., 1., 1., 1., 1., .0], | ||
[-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]], | ||
dtype=np.float32) | ||
|
||
boxes1 = torch.from_numpy(np_boxes1).musa() | ||
boxes2 = torch.from_numpy(np_boxes2).musa() | ||
|
||
np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]]) | ||
ious = diff_iou_rotated_3d(boxes1, boxes2) | ||
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -225,7 +225,8 @@ def test_filtered_lrelu_cuda(self): | |
assert out.shape == (1, 3, 16, 16) | ||
|
||
@pytest.mark.skipif( | ||
is_musa_available(), | ||
True, | ||
# not is_musa_available(), | ||
reason='TODO [email protected]: not supported yet') | ||
def test_filtered_lrelu_musa(self): | ||
out = filtered_lrelu(self.input_tensor.musa(), bias=self.bias.musa()) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,8 @@ class TestNMSQuadri: | |
pytest.param( | ||
'musa', | ||
marks=pytest.mark.skipif( | ||
IS_MUSA_AVAILABLE, | ||
True, | ||
# not IS_MUSA_AVAILABLE, | ||
reason='TODO [email protected]:not supported yet!')), | ||
]) | ||
def test_ml_nms_quadri(self, device): | ||
|
@@ -43,17 +44,13 @@ def test_ml_nms_quadri(self, device): | |
assert np.allclose(dets.cpu().numpy()[:, :8], np_expect_dets) | ||
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds) | ||
|
||
# TODO:[email protected] musa not supported yet! | ||
@pytest.mark.parametrize('device', [ | ||
'cpu', | ||
pytest.param( | ||
'cuda', | ||
marks=pytest.mark.skipif( | ||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')), | ||
pytest.param( | ||
'musa', | ||
marks=pytest.mark.skipif( | ||
IS_MUSA_AVAILABLE, | ||
reason='TODO Not supported yet [email protected]')), | ||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')) | ||
]) | ||
def test_nms_quadri(self, device): | ||
from mmcv.ops import nms_quadri | ||
|
@@ -75,17 +72,13 @@ def test_nms_quadri(self, device): | |
assert np.allclose(dets.cpu().numpy()[:, :8], np_expect_dets) | ||
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds) | ||
|
||
# TODO:[email protected] musa not supported yet! | ||
@pytest.mark.parametrize('device', [ | ||
'cpu', | ||
pytest.param( | ||
'cuda', | ||
marks=pytest.mark.skipif( | ||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')), | ||
pytest.param( | ||
'musa', | ||
marks=pytest.mark.skipif( | ||
IS_MUSA_AVAILABLE, | ||
reason='TODO Not supported yet [email protected]')), | ||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')) | ||
]) | ||
def test_batched_nms(self, device): | ||
# test batched_nms with nms_quadri | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -125,17 +125,13 @@ def test_roialign_float(device, dtype): | |
_test_roialign_allclose(device=device, dtype=dtype) | ||
|
||
|
||
# TODO:[email protected] musa not supported yet! | ||
@pytest.mark.parametrize('device', [ | ||
'cpu', | ||
pytest.param( | ||
'cuda', | ||
marks=pytest.mark.skipif( | ||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')), | ||
pytest.param( | ||
'musa', | ||
marks=pytest.mark.skipif( | ||
IS_MUSA_AVAILABLE, | ||
reason='TODO:[email protected] not supported yet!')), | ||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')) | ||
]) | ||
def test_roialign_float64(device): | ||
_test_roialign_allclose(device=device, dtype=torch.double) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters