Skip to content

Commit

Permalink
fix bug and add more ut
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Sep 10, 2024
1 parent 6dd5e05 commit ef86ce0
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 27 deletions.
65 changes: 40 additions & 25 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def parse_stru_pos(pos_line):
- `cs` or `constrain`: three numbers, which take value in 0 or 1, control the spin constraint of the atom.
- `lambda`: three numbers, control the lambda of the atom.
"""
pos_line = pos_line.split("#")[0] # remove comments
sline = pos_line.split()
pos = [float(i) for i in sline[:3]]
move = None
Expand All @@ -142,30 +143,36 @@ def parse_stru_pos(pos_line):
constrain = None
lambda1 = None
if len(sline) > 3:
mag_list = []
velocity_list = []
mag_list = None
velocity_list = None
move_list = []
angle1_list = []
angle2_list = []
constrain_list = []
lambda_list = []
angle1_list = None
angle2_list = None
constrain_list = None
lambda_list = None
label = "move"
for i in range(3, len(sline)):
# firstly read the label
if sline[i] == "m":
label = "move"
elif sline[i] in ["v", "vel", "velocity"]:
label = "velocity"
velocity_list = []
elif sline[i] in ["mag", "magmom"]:
label = "magmom"
mag_list = []
elif sline[i] == "angle1":
label = "angle1"
angle1_list = []
elif sline[i] == "angle2":
label = "angle2"
angle2_list = []
elif sline[i] in ["constrain", "sc"]:
label = "constrain"
constrain_list = []
elif sline[i] in ["lambda"]:
label = "lambda"
lambda_list = []

# the read the value to the list
elif label == "move":
Expand All @@ -183,45 +190,53 @@ def parse_stru_pos(pos_line):
elif label == "lambda":
lambda_list.append(float(sline[i]))

if len(move_list) > 0:
assert(len(move_list) == 3), f"should 3 numbers to define if atom can move, but got {len(move_list)}"
move = move_list
if move_list is not None:
if len(move_list) == 3:
move = move_list
else:
raise RuntimeError(f"Invalid setting of move: {pos_line}")

if len(velocity_list) > 0:
assert(len(velocity_list) == 3), f"should 3 numbers to define velocity, but got {len(velocity_list)}"
velocity = velocity_list
if velocity_list is not None:
if len(velocity_list) == 3:
velocity = velocity_list
else:
raise RuntimeError(f"Invalid setting of velocity: {pos_line}")

if len(mag_list) > 0:
if mag_list is not None:
if len(mag_list) == 3:
magmom = mag_list
elif len(mag_list) == 1:
magmom = mag_list[0]
else:
raise RuntimeError(f"Invalid magnetic moment {mag_list}")
raise RuntimeError(f"Invalid magnetic moment {pos_line}")

if len(angle1_list) > 0:
assert(len(angle1_list) == 1), f"should 1 number to define angle1, but got {len(angle1_list)}"
angle1 = angle1_list[0]
if angle1_list is not None:
if len(angle1_list) == 1:
angle1 = angle1_list[0]
else:
raise RuntimeError(f"Invalid angle1 {pos_line}")

if len(angle2_list) > 0:
assert(len(angle2_list) == 1), f"should 1 number to define angle2, but got {len(angle2_list)}"
angle2 = angle2_list[0]
if angle2_list is not None:
if len(angle2_list) == 1:
angle2 = angle2_list[0]
else:
raise RuntimeError(f"Invalid angle2 {pos_line}")

if len(constrain_list) > 0:
if constrain_list is not None:
if len(constrain_list) == 3:
constrain = constrain_list
elif len(constrain_list) == 1:
constrain = constrain_list[0]
else:
raise RuntimeError(f"Invalid constrain {constrain_list}")
raise RuntimeError(f"Invalid constrain {pos_line}")

if len(lambda_list) > 0:
if lambda_list is not None:
if len(lambda_list) == 3:
lambda1 = lambda_list
elif len(lambda_list) == 1:
lambda1 = lambda_list[0]
else:
raise RuntimeError(f"Invalid lambda {lambda_list}")
raise RuntimeError(f"Invalid lambda {pos_line}")

return pos,move,velocity,magmom,angle1,angle2,constrain,lambda1

Expand Down Expand Up @@ -761,7 +776,7 @@ def ndarray2list(i):
["1" if ii else "0" for ii in move[natom_tot]]
)
elif isinstance(ndarray2list(move[natom_tot]), (int, float, bool)):
iout += " " + "1 1 1" if move[natom_tot] else "0 0 0"
iout += " 1 1 1" if move[natom_tot] else " 0 0 0"
else:
iout += " 1 1 1"

Expand Down
4 changes: 2 additions & 2 deletions tests/abacus.scf/stru.ref
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ H
4
5.416431453540 4.011298860305 3.511161492417 1 1 1 mag 1.000000000000 1.000000000000 1.000000000000 angle2 90.000000000000 sc 1 lambda 0.400000000000 0.500000000000 0.600000000000
4.131588222365 4.706745191323 4.431136645083 1 1 1 mag 1.000000000000 angle1 100.000000000000 angle2 80.000000000000 sc 1 0 1 lambda 0.700000000000 0.800000000000 0.900000000000
5.630930319126 5.521640894956 4.4503565413030 0 0 mag 1.000000000000 angle1 90.000000000000 angle2 70.0000000000000
5.499851012568 4.003388899277 5.3426218426220 0 0 mag 1.000000000000 angle1 80.000000000000 sc 1
5.630930319126 5.521640894956 4.450356541303 0 0 0 mag 1.000000000000 angle1 90.000000000000 angle2 70.0000000000000
5.499851012568 4.003388899277 5.342621842622 0 0 0 mag 1.000000000000 angle1 80.000000000000 sc 1
49 changes: 49 additions & 0 deletions tests/test_abacus_stru_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,55 @@ def test_parse_stru_post(self):
self.assertEqual(constrain, None)
self.assertEqual(lambda1, None)

def test_parse_stru_error(self):
line = "1.0 2.0 3.0 1 1"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 mag 1.0 3.0 v 1 1 1"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 mag 1 2 3 4"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 v 1"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 v 1 1"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 v 1 1 1 1"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 1"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 angle1 "
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 angle1 1 2"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 1 1 1 angle2"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 angle2 1 2"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 sc"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 sc 1 2"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 lambda"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 lambda 1 2"
self.assertRaises(RuntimeError, parse_stru_pos, line), line

line = "1.0 2.0 3.0 lambda 1 2 3 4"
self.assertRaises(RuntimeError, parse_stru_pos, line), line


if __name__ == "__main__":
unittest.main()

0 comments on commit ef86ce0

Please sign in to comment.