From e2e8b0a4a41fe7a8eae6c5da103e1824dc73a8cb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 8 Jan 2025 15:33:01 +0800 Subject: [PATCH 1/3] feat(tf): support tensor fitting with hybrid descriptor Fix #4527. Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/descriptor.py | 9 +- deepmd/tf/descriptor/hybrid.py | 18 +++ deepmd/tf/model/model.py | 5 + source/tests/tf/test_dipole_hybrid_descrpt.py | 145 ++++++++++++++++++ 4 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 source/tests/tf/test_dipole_hybrid_descrpt.py diff --git a/deepmd/tf/descriptor/descriptor.py b/deepmd/tf/descriptor/descriptor.py index dd86beb21e..bd1af8c72e 100644 --- a/deepmd/tf/descriptor/descriptor.py +++ b/deepmd/tf/descriptor/descriptor.py @@ -105,7 +105,8 @@ def get_dim_rot_mat_1(self) -> int: int the first dimension of the rotation matrix """ - raise NotImplementedError + # by default, no rotation matrix + return 0 def get_nlist(self) -> tuple[tf.Tensor, tf.Tensor, list[int], list[int]]: """Returns neighbor information. @@ -534,3 +535,9 @@ def serialize(self, suffix: str = "") -> dict: def input_requirement(self) -> list[DataRequirementItem]: """Return data requirements needed for the model input.""" return [] + + def get_rot_mat(self) -> tf.Tensor: + """Get rotational matrix.""" + nframes = tf.shape(self.dout)[0] + natoms = tf.shape(self.dout)[1] + return tf.zeros([nframes, natoms, 0], dtype=GLOBAL_TF_FLOAT_PRECISION) diff --git a/deepmd/tf/descriptor/hybrid.py b/deepmd/tf/descriptor/hybrid.py index 2ee35d9ebb..57c21f0ee6 100644 --- a/deepmd/tf/descriptor/hybrid.py +++ b/deepmd/tf/descriptor/hybrid.py @@ -492,3 +492,21 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid": if hasattr(ii, "type_embedding"): raise NotImplementedError("hybrid + type embedding is not supported") return obj + + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. The rotation is of shape + dim_1 x 3. + + Returns + ------- + int + the first dimension of the rotation matrix + """ + return sum([ii.get_dim_rot_mat_1() for ii in self.descrpt_list]) + + def get_rot_mat(self) -> tf.Tensor: + """Get rotational matrix.""" + all_rot_mat = [] + for ii in self.descrpt_list: + all_rot_mat.append(ii.get_rot_mat()) + return tf.concat(all_rot_mat, axis=2) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 8991bf1baf..3377ed2d51 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -668,6 +668,11 @@ def __init__( else: if fitting_net["type"] in ["dipole", "polar"]: fitting_net["embedding_width"] = self.descrpt.get_dim_rot_mat_1() + if fitting_net["embedding_width"] == 0: + raise ValueError( + "This descriptor cannot provide a rotation matrix " + "for a tensorial fitting." + ) self.fitting = Fitting( **fitting_net, descrpt=self.descrpt, diff --git a/source/tests/tf/test_dipole_hybrid_descrpt.py b/source/tests/tf/test_dipole_hybrid_descrpt.py new file mode 100644 index 0000000000..eb5b2b915a --- /dev/null +++ b/source/tests/tf/test_dipole_hybrid_descrpt.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.tf.descriptor.hybrid import ( + DescrptHybrid, +) +from deepmd.tf.env import ( + tf, +) +from deepmd.tf.fit import ( + DipoleFittingSeA, +) +from deepmd.tf.model import ( + DipoleModel, +) + +from .common import ( + DataSystem, + gen_data, + j_loader, +) + +GLOBAL_ENER_FLOAT_PRECISION = tf.float64 +GLOBAL_TF_FLOAT_PRECISION = tf.float64 +GLOBAL_NP_FLOAT_PRECISION = np.float64 + + +class TestModel(tf.test.TestCase): + def setUp(self) -> None: + gen_data() + + def test_model(self) -> None: + jfile = "polar_se_a.json" + jdata = j_loader(jfile) + + systems = jdata["systems"] + set_pfx = "set" + batch_size = jdata["batch_size"] + test_size = jdata["numb_test"] + batch_size = 1 + test_size = 1 + rcut = jdata["model"]["descriptor"]["rcut"] + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) + + test_data = data.get_test() + numb_test = 1 + + descrpt = DescrptHybrid( + list=[ + { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [2, 4, 8], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [2, 4, 8], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + { + "type": "se_e3", + "sel": [5, 5], + "rcut_smth": 1.8, + "rcut": 2.0, + "neuron": [2], + "resnet_dt": False, + "precision": "float64", + "seed": 1, + }, + ] + ) + jdata["model"]["fitting_net"].pop("type", None) + jdata["model"]["fitting_net"].pop("fit_diag", None) + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() + jdata["model"]["fitting_net"]["embedding_width"] = descrpt.get_dim_rot_mat_1() + fitting = DipoleFittingSeA(**jdata["model"]["fitting_net"], uniform_seed=True) + model = DipoleModel(descrpt, fitting) + + # model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']]) + input_data = { + "coord": [test_data["coord"]], + "box": [test_data["box"]], + "type": [test_data["type"]], + "natoms_vec": [test_data["natoms_vec"]], + "default_mesh": [test_data["default_mesh"]], + "fparam": [test_data["fparam"]], + } + model._compute_input_stat(input_data) + + t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") + t_type = tf.placeholder(tf.int32, [None], name="i_type") + t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") + t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") + is_training = tf.placeholder(tf.bool) + t_fparam = None + + model_pred = model.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + t_fparam, + suffix="dipole_se_a", + reuse=False, + ) + dipole = model_pred["dipole"] + gdipole = model_pred["global_dipole"] + force = model_pred["force"] + virial = model_pred["virial"] + atom_virial = model_pred["atom_virial"] + + feed_dict_test = { + t_prop_c: test_data["prop_c"], + t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), + t_box: test_data["box"][:numb_test, :], + t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), + t_natoms: test_data["natoms_vec"], + t_mesh: test_data["default_mesh"], + is_training: False, + } + + sess = self.cached_session().__enter__() + sess.run(tf.global_variables_initializer()) + [p, gp, f, v, av] = sess.run( + [dipole, gdipole, force, virial, atom_virial], feed_dict=feed_dict_test + ) From 6e01a3d2f3bd4b27706f7918c1c01b90fe172751 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 8 Jan 2025 16:44:18 +0800 Subject: [PATCH 2/3] Update source/tests/tf/test_dipole_hybrid_descrpt.py Signed-off-by: Jinzhe Zeng --- source/tests/tf/test_dipole_hybrid_descrpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/tf/test_dipole_hybrid_descrpt.py b/source/tests/tf/test_dipole_hybrid_descrpt.py index eb5b2b915a..7d49fc30b3 100644 --- a/source/tests/tf/test_dipole_hybrid_descrpt.py +++ b/source/tests/tf/test_dipole_hybrid_descrpt.py @@ -119,7 +119,7 @@ def test_model(self) -> None: t_box, t_mesh, t_fparam, - suffix="dipole_se_a", + suffix="dipole_hybrid", reuse=False, ) dipole = model_pred["dipole"] From 088450e0017a1bcafae8d4fa409db314d7c7a8fd Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 8 Jan 2025 16:46:34 +0800 Subject: [PATCH 3/3] Apply suggestions from code review Signed-off-by: Jinzhe Zeng --- source/tests/tf/test_dipole_hybrid_descrpt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/tests/tf/test_dipole_hybrid_descrpt.py b/source/tests/tf/test_dipole_hybrid_descrpt.py index 7d49fc30b3..cc500c43ac 100644 --- a/source/tests/tf/test_dipole_hybrid_descrpt.py +++ b/source/tests/tf/test_dipole_hybrid_descrpt.py @@ -35,8 +35,6 @@ def test_model(self) -> None: systems = jdata["systems"] set_pfx = "set" - batch_size = jdata["batch_size"] - test_size = jdata["numb_test"] batch_size = 1 test_size = 1 rcut = jdata["model"]["descriptor"]["rcut"]