diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 802183832e..dbeb86079d 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -97,7 +97,7 @@ def construct_dataset(system): global_rank = dist.get_rank() if dist.is_initialized() else 0 if global_rank == 0: log.info(f"Constructing DataLoaders from {len(systems)} systems") - with Pool(os.cpu_count()) as pool: + with Pool(env.NUM_WORKERS) as pool: self.systems = pool.map(construct_dataset, systems) else: self.systems = [None] * len(systems) # type: ignore