diff --git a/tests/test_utils/test_transformer.py b/tests/test_utils/test_transformer.py index 4ece46ab2..173e9ffd2 100644 --- a/tests/test_utils/test_transformer.py +++ b/tests/test_utils/test_transformer.py @@ -43,10 +43,8 @@ def test_transforms(): dtype=torch.float32) pts_torch = rtf.obb2poly(box_torch[None], version='full360')[0] box2_torch = rtf.poly2obb(pts_torch, version='full360')[0] - torch.testing.assert_close(box_torch, box2_torch, rtol=1e-4, atol=1e-4) + torch.norm(box_torch - box2_torch) < 1e-4 - # compatibility - torch.testing.assert_close( - box_torch, torch.from_numpy(box_np), rtol=1e-4, atol=1e-4) - torch.testing.assert_close( - pts_torch, torch.from_numpy(pts_np), rtol=1e-4, atol=1e-4) + # compatibility between numpy and torch implementations + torch.norm(box_torch - torch.from_numpy(box_np)) < 1e-4 + torch.norm(pts_torch - torch.from_numpy(pts_np)) < 1e-4