diff --git a/tests/embedding/test_embedding.py b/tests/embedding/test_embedding.py index c6c6d1baa..d6f4bcb8e 100644 --- a/tests/embedding/test_embedding.py +++ b/tests/embedding/test_embedding.py @@ -66,8 +66,8 @@ def test_embed_correct_order(self): device=device, ) - np.testing.assert_allclose(embeddings_1_worker, embeddings_4_worker, rtol=5e-5) - np.testing.assert_allclose(labels_1_worker, labels_4_worker, rtol=1e-5) + np.testing.assert_allclose(embeddings_1_worker, embeddings_4_worker, atol=5e-4) + np.testing.assert_allclose(labels_1_worker, labels_4_worker, atol=1e-5) self.assertListEqual(filenames_1_worker, filenames_4_worker) self.assertListEqual(filenames_1_worker, dataset.get_filenames())