Skip to content

Commit

Permalink
add huber loss
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 5, 2025
1 parent 930e581 commit 2964c01
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 3 deletions.
91 changes: 88 additions & 3 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,43 @@
)


def custom_huber_loss(predictions, targets, delta=1.0):
error = targets - predictions
abs_error = torch.abs(error)
quadratic_loss = 0.5 * torch.pow(error, 2)
linear_loss = delta * (abs_error - 0.5 * delta)
loss = torch.where(abs_error <= delta, quadratic_loss, linear_loss)
return torch.mean(loss)


def custom_step_huber_loss(predictions, targets, delta=1.0):
error = targets - predictions
abs_error = torch.abs(error)
abs_targets = torch.abs(targets)

# Define the different delta values based on the absolute value of targets
delta1 = delta
delta2 = 0.7 * delta
delta3 = 0.4 * delta
delta4 = 0.1 * delta

# Determine which delta to use based on the absolute value of targets
delta_values = torch.where(
abs_targets < 100,
delta1,
torch.where(
abs_targets < 200, delta2, torch.where(abs_targets < 300, delta3, delta4)
),
)

# Compute the quadratic and linear loss based on the dynamically selected delta values
quadratic_loss = 0.5 * torch.pow(error, 2)
linear_loss = delta_values * (abs_error - 0.5 * delta_values)
# Select the appropriate loss based on whether abs_error is less than or greater than delta_values
loss = torch.where(abs_error <= delta_values, quadratic_loss, linear_loss)
return torch.mean(loss)


class EnergyStdLoss(TaskLoss):
def __init__(
self,
Expand All @@ -41,6 +78,9 @@ def __init__(
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
inference=False,
use_huber=False,
huber_delta=0.01,
torch_huber=False,
**kwargs,
) -> None:
r"""Construct a layer to compute loss on energy, force and virial.
Expand Down Expand Up @@ -118,6 +158,10 @@ def __init__(
)
self.use_l1_all = use_l1_all
self.inference = inference
self.huber = use_huber
self.huber_delta = huber_delta
self.torch_huber = torch_huber
self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta)

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force.
Expand Down Expand Up @@ -180,7 +224,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
loss += atom_norm * (pref_e * l2_ener_loss)
if not self.huber:
loss += atom_norm * (pref_e * l2_ener_loss)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
rmse_e = l2_ener_loss.sqrt() * atom_norm
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
Expand Down Expand Up @@ -233,7 +291,20 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if not self.huber:
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
model_pred["force"], label["force"]
)
else:
l_huber_loss = custom_huber_loss(
force_pred.reshape(-1),
force_label.reshape(-1),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(
rmse_f.detach(), find_force
Expand Down Expand Up @@ -304,7 +375,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
if not self.huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
atom_norm * model_pred["virial"],
atom_norm * label["virial"],
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
Expand Down
18 changes: 18 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,6 +2506,24 @@ def loss_ener():
default=0,
doc=doc_numb_generalized_coord,
),
Argument(
"use_huber",
bool,
optional=True,
default=False,
),
Argument(
"huber_delta",
float,
optional=True,
default=0.01,
),
Argument(
"torch_huber",
bool,
optional=True,
default=False,
),
]


Expand Down

0 comments on commit 2964c01

Please sign in to comment.