diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 88735ff18515..35c5b736bd96 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -148,8 +148,8 @@ def test_onnx_ops(self): @unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support") def test_softmax_bf16_large(self): - if not torch.cuda.is_available(): - # only test bf16 on cuda + if torch.version.cuda is None: + # Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen. return class Model(torch.nn.Module): @@ -175,7 +175,7 @@ def forward(self, input): data_ort.requires_grad = True ort_res = ort_model(input=data_ort) ort_res.backward(gradient=init_grad) - # compara result + # compare result torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)