From e88f18360dc4491a2a8c8e5b5a5e7c9b4aef2b09 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 13 Jan 2025 15:08:32 +0800 Subject: [PATCH] chore: test consistency of rotation matrix Signed-off-by: Jinzhe Zeng --- source/tests/consistent/descriptor/common.py | 2 +- source/tests/consistent/descriptor/test_dpa1.py | 2 +- source/tests/consistent/descriptor/test_hybrid.py | 2 +- source/tests/consistent/descriptor/test_se_atten_v2.py | 2 +- source/tests/consistent/descriptor/test_se_e2_a.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index baa6e97d04..617312145e 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -74,7 +74,7 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix): ) # ensure get_dim_out gives the correct shape t_des = tf.reshape(t_des, [1, natoms[0], obj.get_dim_out()]) - return [t_des], { + return [t_des, obj.get_rot_mat()], { t_coord: coords, t_type: atype, t_natoms: natoms, diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 92b2c6bd0b..db5fe4dae0 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -442,7 +442,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) @property def rtol(self) -> float: diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py index 2cbbd92d84..a042ec0311 100644 --- a/source/tests/consistent/descriptor/test_hybrid.py +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -168,4 +168,4 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index af041c6cce..e1ef3b897d 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -340,7 +340,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) @property def rtol(self) -> float: diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index 8838696108..140e09c544 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -259,7 +259,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) @property def rtol(self) -> float: