Skip to content

Commit

Permalink
fix: add optional force check
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Jan 15, 2025
1 parent 6387419 commit 1ccb7df
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def calculate(
self.results["energy"] = data["energies"][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
self.results["free_energy"] = data["energies"][0]
self.results["forces"] = data["forces"][0]
if "forces" in data:
self.results["forces"] = data["forces"][0]
if "virials" in data:
self.results["virial"] = data["virials"][0].reshape(3, 3)

Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
cell=data["cells"][ii],
)

results = {"energy": data["energies"][ii], "forces": data["forces"][ii]}
results = {"energy": data["energies"][ii]}
if "forces" in data:
results["forces"] = data["forces"][ii]
if "virials" in data:
# convert to GPa as this is ase convention
# v_pref = 1 * 1e4 / 1.602176621e6
Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/pwmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ def from_labeled_system(
data["cells"],
data["coords"],
data["energies"],
data["forces"],
tmp_force,
tmp_virial,
) = dpdata.pwmat.movement.get_frames(
file_name, begin=begin, step=step, convergence_check=convergence_check
)
if tmp_force is not None:
data["forces"] = tmp_force
if tmp_virial is not None:
data["virials"] = tmp_virial
# scale virial to the unit of eV
Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def from_labeled_system(
data["cells"],
data["coords"],
data["energies"],
data["forces"],
tmp_force,
tmp_virial,
) = dpdata.vasp.outcar.get_frames(
file_name,
Expand All @@ -90,6 +90,8 @@ def from_labeled_system(
ml=ml,
convergence_check=convergence_check,
)
if tmp_force is not None:
data["forces"] = tmp_force
if tmp_virial is not None:
data["virials"] = tmp_virial
# scale virial to the unit of eV
Expand Down

0 comments on commit 1ccb7df

Please sign in to comment.