Skip to content

Commit

Permalink
Perf: load data systems on rank 0
Browse files Browse the repository at this point in the history
  • Loading branch information
caic99 committed Dec 19, 2024
1 parent 104fc36 commit 5ad15d1
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import time
from multiprocessing.dummy import (
from multiprocessing import (
Pool,
)
from queue import (
Expand Down Expand Up @@ -88,25 +88,25 @@ def __init__(
systems = [os.path.join(systems, item) for item in file.keys()]

self.systems: list[DeepmdDataSetForLoader] = []
if len(systems) >= 100:
log.info(f"Constructing DataLoaders from {len(systems)} systems")

def construct_dataset(system):
if len(systems) >= 100:
log.info(f"Constructing DataLoaders from {len(systems)} systems")
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)

with Pool(
os.cpu_count()
// (
int(os.environ["LOCAL_WORLD_SIZE"])
if dist.is_available() and dist.is_initialized()
else 1
)
) as pool:
self.systems = pool.map(construct_dataset, systems)

global_rank = dist.get_rank() if dist.is_initialized() else 0
if global_rank == 0:
with Pool(os.cpu_count()) as pool:
self.systems = pool.map(construct_dataset, systems)
if dist.is_initialized():
dist.broadcast_object_list(self.systems)
else:
self.systems = [None] * len(systems) # type: ignore
dist.broadcast_object_list(self.systems)
assert self.systems[-1] is not None
self.sampler_list: list[DistributedSampler] = []
self.index = []
self.total_batch = 0
Expand Down

0 comments on commit 5ad15d1

Please sign in to comment.