diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 9b0b0a930e..16c7810727 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -656,8 +656,6 @@ def step(_step_id, task_key="Default") -> None: # PyTorch Profiler if self.enable_profiler or self.profiling: prof.step() - if not self.wrapper.training: - self.wrapper.train() if isinstance(self.lr_exp, dict): _lr = self.lr_exp[task_key] else: @@ -766,7 +764,7 @@ def fake_model(): if self.display_in_training and ( display_step_id % self.disp_freq == 0 or display_step_id == 1 ): - self.wrapper.eval() + self.wrapper.eval() # Will set to train mode before fininshing validation def log_loss_train(_loss, _more_loss, _task_key="Default"): results = {} @@ -872,6 +870,7 @@ def log_loss_valid(_task_key="Default"): learning_rate=None, ) ) + self.wrapper.train() current_time = time.time() train_time = current_time - self.t0 @@ -926,7 +925,7 @@ def log_loss_valid(_task_key="Default"): writer.add_scalar( f"{task_key}/{item}", more_loss[item], display_step_id ) - + self.wrapper.train() self.t0 = time.time() self.total_train_time = 0.0 for step_id in range(self.num_steps):