diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index baa6e97d04..617312145e 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -74,7 +74,7 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix): ) # ensure get_dim_out gives the correct shape t_des = tf.reshape(t_des, [1, natoms[0], obj.get_dim_out()]) - return [t_des], { + return [t_des, obj.get_rot_mat()], { t_coord: coords, t_type: atype, t_natoms: natoms, diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 92b2c6bd0b..db5fe4dae0 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -442,7 +442,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) @property def rtol(self) -> float: diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py index 2cbbd92d84..a042ec0311 100644 --- a/source/tests/consistent/descriptor/test_hybrid.py +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -168,4 +168,4 @@ def eval_jax(self, jax_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index af041c6cce..e1ef3b897d 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -340,7 +340,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) @property def rtol(self) -> float: diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index 8838696108..140e09c544 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -259,7 +259,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + return (ret[0], ret[1]) @property def rtol(self) -> float: