From 78b2a10477d13858b4c2ba0f70151bdbb1165b44 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 02:04:31 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 8ec8f7a5cb..822ce1be04 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -8,11 +8,10 @@ Optional, Union, ) + import numpy as np import torch -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) + from deepmd.pt.utils import ( AtomExcludeMask, ) @@ -35,6 +34,7 @@ log = logging.getLogger(__name__) + def make_stat_input( datasets, dataloaders, @@ -103,7 +103,9 @@ def process_with_new_frame(sys_indices, newele_counter, miss): newele_counter += 1 if newele_counter <= min_frames_per_element_forstat: frame_data = sys.__getitem__(frame) - assert miss in frame_data['atype'], f"Missing element '{miss}' not found in frame data." + assert ( + miss in frame_data["atype"] + ), f"Missing element '{miss}' not found in frame data." sys_stat_new = {} for dd in frame_data: if dd == "type": @@ -211,6 +213,7 @@ def finalize_stats(sys_stat): process_with_new_frame(sys_indices, newele_counter, miss) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"],