Skip to content

Commit

Permalink
fix condition block dtype mismatch in jit.save and enable 2 unitest
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Oct 31, 2024
1 parent 7b2476f commit 26047e9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
6 changes: 5 additions & 1 deletion deepmd/pd/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,11 @@ def build_multiple_neighbor_list(
).to(device=nlist.place)
# nb x nloc x nsel
nlist = paddle.concat([nlist, pad], axis=-1)
nsel = nsels[-1]
if paddle.is_tensor(nsel):
nsel = paddle.to_tensor(nsels[-1], dtype=nsel.dtype)
else:
nsel = nsels[-1]

# nb x nall x 3
coord1 = coord.reshape([nb, -1, 3])
nall = coord1.shape[1]
Expand Down
1 change: 0 additions & 1 deletion source/tests/pd/model/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def tearDown(self):
JITTest.tearDown(self)


@unittest.skip("var dtype int32/int64 confused in if block")
class TestEnergyModelDPA2(unittest.TestCase, JITTest):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
Expand Down
1 change: 0 additions & 1 deletion source/tests/pd/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def tearDown(self):
shutil.rmtree(f)


@unittest.skip("Paddle do not support MultiTaskSeA.")
class TestMultiTaskSeA(unittest.TestCase, MultiTaskTrainTest):
def setUp(self):
multitask_se_e2_a = deepcopy(multitask_template)
Expand Down

0 comments on commit 26047e9

Please sign in to comment.