Skip to content

Commit

Permalink
fix pd
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 5, 2025
1 parent 13f6b90 commit 219c417
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions source/tests/consistent/loss/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import paddle

from deepmd.pd.loss.ener import EnergyStdLoss as EnerLossPD
from deepmd.pd.utils.env import DEVICE as PD_DEVICE
else:
EnerLossPD = None
if INSTALLED_JAX:
Expand Down Expand Up @@ -213,8 +214,12 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return loss, more_loss

def eval_pd(self, pd_obj: Any) -> Any:
predict = {kk: paddle.asarray(vv) for kk, vv in self.predict.items()}
label = {kk: paddle.asarray(vv) for kk, vv in self.label.items()}
predict = {
kk: paddle.to_tensor(vv).to(PD_DEVICE) for kk, vv in self.predict.items()
}
label = {
kk: paddle.to_tensor(vv).to(PD_DEVICE) for kk, vv in self.label.items()
}
predict["atom_energy"] = predict.pop("atom_ener")
_, loss, more_loss = pd_obj(
{},
Expand Down

0 comments on commit 219c417

Please sign in to comment.