Skip to content

Commit

Permalink
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
Browse files Browse the repository at this point in the history
…devel
  • Loading branch information
SumGuo-88 committed Jan 6, 2025
2 parents 3ccb4b9 + 26205d7 commit f669ac5
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions source/tests/pt/test_make_stat_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def count_non_zero_elements(self, tensor, threshold=1e-8):
return torch.sum(torch.abs(tensor) > threshold).item()

def test_make_stat_input(self):
#3 frames would be count
# 3 frames would be count
lst = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
Expand All @@ -87,8 +87,13 @@ def test_make_stat_input(self):
)

def test_make_stat_input_nocomplete(self):
<<<<<<< HEAD
#missing element:13,31,37
#only one frame would be count
=======
# missing element:13,31,37
# only one frame would be count
>>>>>>> 26205d74c44201c56d39241b75e26ae381fcf67f

lst = make_stat_input(
datasets=self.datasets,
Expand Down Expand Up @@ -125,13 +130,11 @@ def test_bias(self):
bias_all, _ = compute_output_stats(lst_all, ntypes=57)
energy_ori = np.array(bias_ori.get("energy").cpu()).flatten()
energy_all = np.array(bias_all.get("energy").cpu()).flatten()

for i, (e_ori, e_all) in enumerate(zip(energy_ori, energy_all)):
if e_all == 0:
self.assertEqual(
e_ori,
0,
f"Index {i}: energy_all=0, but energy_ori={e_ori}"
e_ori, 0, f"Index {i}: energy_all=0, but energy_ori={e_ori}"
)
else:
if e_ori != 0:
Expand All @@ -140,8 +143,9 @@ def test_bias(self):
self.assertTrue(
rel_diff < 0.4,
f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, "
f"relative difference {rel_diff:.2%} is too large"
f"relative difference {rel_diff:.2%} is too large",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit f669ac5

Please sign in to comment.