Skip to content

Commit

Permalink
feat: dpmodel energy loss & consistent tests (#4531)
Browse files Browse the repository at this point in the history
Fix #4105. Fix #4429.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced a new energy loss calculation framework with support for
multiple machine learning backends.
- Added serialization and deserialization capabilities for loss modules.
- Added a new class `EnergyLoss` for computing energy-related loss
metrics.
  
- **Documentation**
  - Added SPDX license identifiers to multiple files.
  - Included docstrings for new classes and methods.

- **Tests**
- Implemented comprehensive test suite for energy loss functions across
different platforms (TensorFlow, PyTorch, Paddle, JAX).
- Introduced a new test class `TestEner` for evaluating energy loss
functions.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jan 7, 2025
1 parent 03819a0 commit dbdb9b9
Show file tree
Hide file tree
Showing 11 changed files with 931 additions and 0 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
339 changes: 339 additions & 0 deletions deepmd/dpmodel/loss/ener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.loss.loss import (
Loss,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.version import (
check_version_compatibility,
)


class EnergyLoss(Loss):
def __init__(
self,
starter_learning_rate: float,
start_pref_e: float = 0.02,
limit_pref_e: float = 1.00,
start_pref_f: float = 1000,
limit_pref_f: float = 1.00,
start_pref_v: float = 0.0,
limit_pref_v: float = 0.0,
start_pref_ae: float = 0.0,
limit_pref_ae: float = 0.0,
start_pref_pf: float = 0.0,
limit_pref_pf: float = 0.0,
relative_f: Optional[float] = None,
enable_atom_ener_coeff: bool = False,
start_pref_gf: float = 0.0,
limit_pref_gf: float = 0.0,
numb_generalized_coord: int = 0,
**kwargs,
) -> None:
self.starter_learning_rate = starter_learning_rate
self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
self.start_pref_f = start_pref_f
self.limit_pref_f = limit_pref_f
self.start_pref_v = start_pref_v
self.limit_pref_v = limit_pref_v
self.start_pref_ae = start_pref_ae
self.limit_pref_ae = limit_pref_ae
self.start_pref_pf = start_pref_pf
self.limit_pref_pf = limit_pref_pf
self.relative_f = relative_f
self.enable_atom_ener_coeff = enable_atom_ener_coeff
self.start_pref_gf = start_pref_gf
self.limit_pref_gf = limit_pref_gf
self.numb_generalized_coord = numb_generalized_coord
self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0
self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0
self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0
self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0
self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0
self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0
if self.has_gf and self.numb_generalized_coord < 1:
raise RuntimeError(
"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
)

def call(
self,
learning_rate: float,
natoms: int,
model_dict: dict[str, np.ndarray],
label_dict: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
"""Calculate loss from model results and labeled results."""
energy = model_dict["energy"]
force = model_dict["force"]
virial = model_dict["virial"]
atom_ener = model_dict["atom_ener"]
energy_hat = label_dict["energy"]
force_hat = label_dict["force"]
virial_hat = label_dict["virial"]
atom_ener_hat = label_dict["atom_ener"]
atom_pref = label_dict["atom_pref"]
find_energy = label_dict["find_energy"]
find_force = label_dict["find_force"]
find_virial = label_dict["find_virial"]
find_atom_ener = label_dict["find_atom_ener"]
find_atom_pref = label_dict["find_atom_pref"]
xp = array_api_compat.array_namespace(
energy,
force,
virial,
atom_ener,
energy_hat,
force_hat,
virial_hat,
atom_ener_hat,
atom_pref,
)

if self.enable_atom_ener_coeff:
# when ener_coeff (\nu) is defined, the energy is defined as
# E = \sum_i \nu_i E_i
# instead of the sum of atomic energies.
#
# A case is that we want to train reaction energy
# A + B -> C + D
# E = - E(A) - E(B) + E(C) + E(D)
# A, B, C, D could be put far away from each other
atom_ener_coeff = label_dict["atom_ener_coeff"]
atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
energy = xp.sum(atom_ener_coeff * atom_ener, 1)
if self.has_f or self.has_pf or self.relative_f or self.has_gf:
force_reshape = xp.reshape(force, [-1])
force_hat_reshape = xp.reshape(force_hat, [-1])
diff_f = force_hat_reshape - force_reshape
else:
diff_f = None

if self.relative_f is not None:
force_hat_3 = xp.reshape(force_hat, [-1, 3])
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f
diff_f_3 = xp.reshape(diff_f, [-1, 3])
diff_f_3 = diff_f_3 / norm_f
diff_f = xp.reshape(diff_f_3, [-1])

atom_norm = 1.0 / natoms
atom_norm_ener = 1.0 / natoms
lr_ratio = learning_rate / self.starter_learning_rate
pref_e = find_energy * (
self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * lr_ratio
)
pref_f = find_force * (
self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * lr_ratio
)
pref_v = find_virial * (
self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * lr_ratio
)
pref_ae = find_atom_ener * (
self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * lr_ratio
)
pref_pf = find_atom_pref * (
self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio
)

l2_loss = 0
more_loss = {}
if self.has_e:
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
l2_loss += pref_f * l2_force_loss
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss, find_force
)
if self.has_v:
virial_reshape = xp.reshape(virial, [-1])
virial_hat_reshape = xp.reshape(virial_hat, [-1])
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)
l2_loss += atom_norm * (pref_v * l2_virial_loss)
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss, find_virial
)
if self.has_ae:
atom_ener_reshape = xp.reshape(atom_ener, [-1])
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
l2_atom_ener_loss = xp.mean(
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
)
l2_loss += pref_ae * l2_atom_ener_loss
more_loss["l2_atom_ener_loss"] = self.display_if_exist(
l2_atom_ener_loss, find_atom_ener
)
if self.has_pf:
atom_pref_reshape = xp.reshape(atom_pref, [-1])
l2_pref_force_loss = xp.mean(
xp.multiply(xp.square(diff_f), atom_pref_reshape),
)
l2_loss += pref_pf * l2_pref_force_loss
more_loss["l2_pref_force_loss"] = self.display_if_exist(
l2_pref_force_loss, find_atom_pref
)
if self.has_gf:
find_drdq = label_dict["find_drdq"]
drdq = label_dict["drdq"]
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
drdq_reshape = xp.reshape(
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
)
gen_force_hat = xp.einsum(
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
)
gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
diff_gen_force = gen_force_hat - gen_force
l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))
pref_gf = find_drdq * (
self.limit_pref_gf
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
)
l2_loss += pref_gf * l2_gen_force_loss
more_loss["l2_gen_force_loss"] = self.display_if_exist(
l2_gen_force_loss, find_drdq
)

self.l2_l = l2_loss
self.l2_more = more_loss
return l2_loss, more_loss

@property
def label_requirement(self) -> list[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_e:
label_requirement.append(
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
)
)
if self.has_f:
label_requirement.append(
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_v:
label_requirement.append(
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
)
)
if self.has_ae:
label_requirement.append(
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_pf:
label_requirement.append(
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
)
)
if self.has_gf > 0:
label_requirement.append(
DataRequirementItem(
"drdq",
ndof=self.numb_generalized_coord * 3,
atomic=True,
must=False,
high_prec=False,
)
)
if self.enable_atom_ener_coeff:
label_requirement.append(
DataRequirementItem(
"atom_ener_coeff",
ndof=1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
)
)
return label_requirement

def serialize(self) -> dict:
"""Serialize the loss module.
Returns
-------
dict
The serialized loss module
"""
return {
"@class": "EnergyLoss",
"@version": 1,
"starter_learning_rate": self.starter_learning_rate,
"start_pref_e": self.start_pref_e,
"limit_pref_e": self.limit_pref_e,
"start_pref_f": self.start_pref_f,
"limit_pref_f": self.limit_pref_f,
"start_pref_v": self.start_pref_v,
"limit_pref_v": self.limit_pref_v,
"start_pref_ae": self.start_pref_ae,
"limit_pref_ae": self.limit_pref_ae,
"start_pref_pf": self.start_pref_pf,
"limit_pref_pf": self.limit_pref_pf,
"relative_f": self.relative_f,
"enable_atom_ener_coeff": self.enable_atom_ener_coeff,
"start_pref_gf": self.start_pref_gf,
"limit_pref_gf": self.limit_pref_gf,
"numb_generalized_coord": self.numb_generalized_coord,
}

@classmethod
def deserialize(cls, data: dict) -> "Loss":
"""Deserialize the loss module.
Parameters
----------
data : dict
The serialized loss module
Returns
-------
Loss
The deserialized loss module
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
return cls(**data)
Loading

0 comments on commit dbdb9b9

Please sign in to comment.