From 1ccb7df802cd95cac1ee33a026a34f3aad4eddb8 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 15 Jan 2025 15:53:24 +0800 Subject: [PATCH] fix: add optional force check --- dpdata/ase_calculator.py | 3 ++- dpdata/plugins/ase.py | 4 +++- dpdata/plugins/pwmat.py | 4 +++- dpdata/plugins/vasp.py | 4 +++- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/dpdata/ase_calculator.py b/dpdata/ase_calculator.py index 1de760a5a..94f5073cb 100644 --- a/dpdata/ase_calculator.py +++ b/dpdata/ase_calculator.py @@ -62,7 +62,8 @@ def calculate( self.results["energy"] = data["energies"][0] # see https://gitlab.com/ase/ase/-/merge_requests/2485 self.results["free_energy"] = data["energies"][0] - self.results["forces"] = data["forces"][0] + if "forces" in data: + self.results["forces"] = data["forces"][0] if "virials" in data: self.results["virial"] = data["virials"][0].reshape(3, 3) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 83a6bcf95..bafd9e7e7 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -175,7 +175,9 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]: cell=data["cells"][ii], ) - results = {"energy": data["energies"][ii], "forces": data["forces"][ii]} + results = {"energy": data["energies"][ii]} + if "forces" in data: + results["forces"] = data["forces"][ii] if "virials" in data: # convert to GPa as this is ase convention # v_pref = 1 * 1e4 / 1.602176621e6 diff --git a/dpdata/plugins/pwmat.py b/dpdata/plugins/pwmat.py index ba3dab160..38a5bb297 100644 --- a/dpdata/plugins/pwmat.py +++ b/dpdata/plugins/pwmat.py @@ -31,11 +31,13 @@ def from_labeled_system( data["cells"], data["coords"], data["energies"], - data["forces"], + tmp_force, tmp_virial, ) = dpdata.pwmat.movement.get_frames( file_name, begin=begin, step=step, convergence_check=convergence_check ) + if tmp_force is not None: + data["forces"] = tmp_force if tmp_virial is not None: data["virials"] = tmp_virial # scale virial to the unit of eV diff --git a/dpdata/plugins/vasp.py b/dpdata/plugins/vasp.py index 0160bde29..da19199f0 100644 --- a/dpdata/plugins/vasp.py +++ b/dpdata/plugins/vasp.py @@ -81,7 +81,7 @@ def from_labeled_system( data["cells"], data["coords"], data["energies"], - data["forces"], + tmp_force, tmp_virial, ) = dpdata.vasp.outcar.get_frames( file_name, @@ -90,6 +90,8 @@ def from_labeled_system( ml=ml, convergence_check=convergence_check, ) + if tmp_force is not None: + data["forces"] = tmp_force if tmp_virial is not None: data["virials"] = tmp_virial # scale virial to the unit of eV