From e9672d079661b8141776d89636a1e212a972d0f9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 27 Nov 2024 10:38:03 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/pt/utils/dataloader.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py
index 59dc2779fa..7e7e2cb12d 100644
--- a/deepmd/pt/utils/dataloader.py
+++ b/deepmd/pt/utils/dataloader.py
@@ -56,12 +56,14 @@ def setup_seed(seed) -> None:
     torch.cuda.manual_seed_all(seed)
     torch.backends.cudnn.deterministic = True
 
+
 def construct_dataset(system, type_map):
     return DeepmdDataSetForLoader(
         system=system,
         type_map=type_map,
     )
 
+
 class DpLoaderSet(Dataset):
     """A dataset for storing DataLoaders to multiple Systems.
 
@@ -97,7 +99,7 @@ def __init__(
         if len(systems) >= 100:
             log.info(f"Constructing DataLoaders from {len(systems)} systems")
 
-        construct_dataset_systems=partial(construct_dataset, type_map=type_map)
+        construct_dataset_systems = partial(construct_dataset, type_map=type_map)
 
         with Pool(
             os.cpu_count()
@@ -222,8 +224,10 @@ def run(self) -> None:
         # Signal the consumer we are done; this should not happen for DataLoader
         self._queue.put(StopIteration)
 
+
 QUEUESIZE = 32
 
+
 class BufferedIterator:
     def __init__(self, iterable) -> None:
         self._queue = Queue(QUEUESIZE)
@@ -242,7 +246,9 @@ def __next__(self):
         start_wait = time.time()
         item = self._queue.get()
         wait_time = time.time() - start_wait
-        if wait_time > 1.0: # Even for Multi-Task training, each step usually takes < 1s
+        if (
+            wait_time > 1.0
+        ):  # Even for Multi-Task training, each step usually takes < 1s
             log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
         if isinstance(item, Exception):
             raise item