Skip to content

Commit

Permalink
support LKF optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Oct 31, 2024
1 parent afd4746 commit 7b2476f
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).

### Highlighted features

- **interfaced with multiple backends**, including TensorFlow, PyTorch, JAX and Paddle the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with multiple backends**, including TensorFlow, PyTorch, JAX and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS.
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
Expand Down
8 changes: 5 additions & 3 deletions deepmd/pd/optimizer/KFWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,17 @@ def update_energy(
mask = error < 0

error = error * update_prefactor
error[mask] = -1 * error[mask]
# error[mask] = -1 * error[mask]
error = _mask_update(error, mask, -error[mask])
error = error.mean()

if self.is_distributed:
dist.all_reduce(error)
error /= dist.get_world_size()

Etot_predict = update_prefactor * Etot_predict
Etot_predict[mask] = -Etot_predict[mask]
# Etot_predict[mask] = -Etot_predict[mask]
Etot_predict = _mask_update(Etot_predict, mask, -Etot_predict[mask])

Etot_predict.sum().backward()
error = error * math.sqrt(bs)
Expand All @@ -91,7 +93,7 @@ def update_force(
error_tmp = Force_label[:, index[i]] - force_predict[:, index[i]]
error_tmp = update_prefactor * error_tmp
mask = error_tmp < 0
error_tmp = _mask_update(error_tmp, mask, -1 * error_tmp[mask])
error_tmp = _mask_update(error_tmp, mask, -error_tmp[mask])
# error_tmp[mask] = -1 * error_tmp[mask]
error = error_tmp.mean() / natoms_sum

Expand Down
1 change: 1 addition & 0 deletions deepmd/pd/optimizer/LKF.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def __update(self, H, error, weights):
def set_grad_prefactor(self, grad_prefactor):
self.grad_prefactor = grad_prefactor

@paddle.no_grad()
def step(self, error):
params_packed_index = self._state.get("params_packed_index")

Expand Down
13 changes: 8 additions & 5 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@
get_model,
get_zbl_model,
)
from deepmd.pd.optimizer import ( # LKFOptimizer,
from deepmd.pd.optimizer import (
KFOptimizerWrapper,
LKFOptimizer,
)
from deepmd.pd.train.wrapper import (
ModelWrapper,
Expand Down Expand Up @@ -601,10 +602,12 @@ def warm_up_linear(step, warmup_steps):
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.set_state_dict(optimizer_state_dict)
elif self.opt_type == "LKF":
raise NotImplementedError("LKF is not supported yet in Paddle backend.")
# self.optimizer = LKFOptimizer(
# [{'params': self.wrapper.parameters()}], 0.98, 0.99870, self.opt_param["kf_blocksize"]
# )
self.optimizer = LKFOptimizer(
[{"params": self.wrapper.parameters()}],
0.98,
0.99870,
self.opt_param["kf_blocksize"],
)
else:
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

Expand Down
1 change: 0 additions & 1 deletion source/tests/pd/test_LKF.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)


@unittest.skip("Paddle do not support LKF now")
class TestLKF(unittest.TestCase):
def test_lkf(self):
with open(str(Path(__file__).parent / "water/lkf.json")) as fin:
Expand Down

0 comments on commit 7b2476f

Please sign in to comment.