Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2024
1 parent 417da94 commit e9672d0
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ def setup_seed(seed) -> None:
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True


def construct_dataset(system, type_map):
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)


class DpLoaderSet(Dataset):
"""A dataset for storing DataLoaders to multiple Systems.
Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(
if len(systems) >= 100:
log.info(f"Constructing DataLoaders from {len(systems)} systems")

construct_dataset_systems=partial(construct_dataset, type_map=type_map)
construct_dataset_systems = partial(construct_dataset, type_map=type_map)

with Pool(
os.cpu_count()
Expand Down Expand Up @@ -222,8 +224,10 @@ def run(self) -> None:
# Signal the consumer we are done; this should not happen for DataLoader
self._queue.put(StopIteration)


QUEUESIZE = 32


class BufferedIterator:
def __init__(self, iterable) -> None:
self._queue = Queue(QUEUESIZE)
Expand All @@ -242,7 +246,9 @@ def __next__(self):
start_wait = time.time()
item = self._queue.get()
wait_time = time.time() - start_wait
if wait_time > 1.0: # Even for Multi-Task training, each step usually takes < 1s
if (
wait_time > 1.0
): # Even for Multi-Task training, each step usually takes < 1s
log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
if isinstance(item, Exception):
raise item
Expand Down

0 comments on commit e9672d0

Please sign in to comment.