From aa8d22e9d41685ff6f4d51f8c869f2608fa3b1ad Mon Sep 17 00:00:00 2001 From: "weihong.xu" Date: Fri, 17 May 2024 13:23:58 +0800 Subject: [PATCH 1/5] improve: ase try to get virials from different sources --- dpdata/plugins/ase.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 1d818483..e8ae4ab8 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -93,13 +93,21 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: "energies": np.array([energies]), "forces": np.array([forces]), } - try: - stress = atoms.get_stress(False) - except PropertyNotImplementedError: - pass - else: - virials = np.array([-atoms.get_volume() * stress]) + + # try to get virials from different sources + virials = atoms.info.get("virial") + if virials is None: + virials = atoms.info.get("virials") + if virials is None: + try: + stress = atoms.get_stress(False) + except PropertyNotImplementedError: + pass + else: + virials = np.array([-atoms.get_volume() * stress]) + if virials is not None: info_dict["virials"] = virials + return info_dict def from_multi_systems( From 42e31bf148a1aceab410c8a203ad24138233d8bc Mon Sep 17 00:00:00 2001 From: "weihong.xu" Date: Sun, 19 May 2024 15:54:40 +0800 Subject: [PATCH 2/5] ase plugin: add doc for virial sources --- dpdata/plugins/ase.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index e8ae4ab8..be3bf1de 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -167,13 +167,18 @@ def to_system(self, data, **kwargs): return structures def to_labeled_system(self, data, *args, **kwargs): - """Convert System to ASE Atoms object.""" + """Convert System to ASE Atoms object. + + Note that this method will try to load virials from the following sources: + - atoms.info['virial'] + - atoms.info['virials'] + - converted from stress tensor + """ from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator structures = [] species = [data["atom_names"][tt] for tt in data["atom_types"]] - for ii in range(data["coords"].shape[0]): structure = Atoms( symbols=species, From 28670397738aa966ea721891c82585441c2f6acd Mon Sep 17 00:00:00 2001 From: "weihong.xu" Date: Sun, 19 May 2024 15:58:32 +0800 Subject: [PATCH 3/5] ase plugin: add doc for virial sources --- dpdata/plugins/ase.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index be3bf1de..0b02139d 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -79,6 +79,12 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: RuntimeError ASE will raise RuntimeError if the atoms does not have a calculator + + + Note that this method will try to load virials from the following sources: + - atoms.info['virial'] + - atoms.info['virials'] + - converted from stress tensor """ from ase.calculators.calculator import PropertyNotImplementedError @@ -167,13 +173,7 @@ def to_system(self, data, **kwargs): return structures def to_labeled_system(self, data, *args, **kwargs): - """Convert System to ASE Atoms object. - - Note that this method will try to load virials from the following sources: - - atoms.info['virial'] - - atoms.info['virials'] - - converted from stress tensor - """ + """Convert System to ASE Atoms object. """ from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator From 451bd0ddc30efc882b70f70ccfcccd4c74e62c11 Mon Sep 17 00:00:00 2001 From: "weihong.xu" Date: Sun, 19 May 2024 15:59:01 +0800 Subject: [PATCH 4/5] ase plugin: add doc for virial sources --- dpdata/plugins/ase.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 0b02139d..59022630 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -62,6 +62,11 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: """Convert ase.Atoms to a LabeledSystem. Energies and forces are calculated by the calculator. + Note that this method will try to load virials from the following sources: + - atoms.info['virial'] + - atoms.info['virials'] + - converted from stress tensor + Parameters ---------- atoms : ase.Atoms @@ -79,12 +84,6 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: RuntimeError ASE will raise RuntimeError if the atoms does not have a calculator - - - Note that this method will try to load virials from the following sources: - - atoms.info['virial'] - - atoms.info['virials'] - - converted from stress tensor """ from ase.calculators.calculator import PropertyNotImplementedError From 774fb7d60ce31426c4f485d6bb729e5d7c0ba3cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 May 2024 08:00:32 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/plugins/ase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 59022630..3ee35c28 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -172,7 +172,7 @@ def to_system(self, data, **kwargs): return structures def to_labeled_system(self, data, *args, **kwargs): - """Convert System to ASE Atoms object. """ + """Convert System to ASE Atoms object.""" from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator