diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f74c4769bf..b48c4993af 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -653,6 +653,12 @@ def run(self) -> None: prof.start() def step(_step_id, task_key="Default") -> None: + if self.multi_task: + model_index = dp_random.choice( + np.arange(self.num_model), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] # PyTorch Profiler if self.enable_profiler or self.profiling: prof.step() @@ -929,24 +935,8 @@ def log_loss_valid(_task_key="Default"): self.t0 = time.time() self.total_train_time = 0.0 - for step_id in range(self.num_steps): - if step_id < self.start_step: - continue - if self.multi_task: - chosen_index_list = dp_random.choice( - np.arange( - self.num_model, dtype=np.int32 - ), # int32 should be enough for # models... - p=np.array(self.model_prob), - size=self.world_size, - replace=True, - ) - assert chosen_index_list.size == self.world_size - model_index = chosen_index_list[self.rank] - model_key = self.model_keys[model_index] - else: - model_key = "Default" - step(step_id, model_key) + for step_id in range(self.start_step, self.num_steps): + step(step_id) if JIT: break