From 0dabf77f92e409967133f71e7c341a2c9721c73a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Jan 2025 05:50:17 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 80 +++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 9fb2db5c41..d246708e48 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import numpy as np -import torch from collections import ( defaultdict, ) @@ -11,13 +9,9 @@ Union, ) -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) - - import numpy as np import torch + from deepmd.pt.utils import ( AtomExcludeMask, ) @@ -40,7 +34,14 @@ log = logging.getLogger(__name__) -def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): + +def make_stat_input( + datasets, + dataloaders, + nbatches, + min_frames_per_element_forstat, + enable_element_completion=True, +): """Pack data for statistics. Element checking is only enabled with mixed_type. @@ -63,11 +64,13 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors if datasets[0].mixed_type: if enable_element_completion: log.info( - f'Element check enabled. ' - f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' + f"Element check enabled. " + f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." ) else: - log.info("Element completion is disabled. Skipping missing element handling.") + log.info( + "Element completion is disabled. Skipping missing element handling." + ) def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -93,8 +96,8 @@ def process_batches(dataloader, sys_stat): def process_with_new_frame(sys_indices, newele_counter): for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] sys = datasets[sys_index] for frame in frames: newele_counter += 1 @@ -126,7 +129,10 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -137,7 +143,7 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type and enable_element_completion: - element_data = torch.cat(sys_stat['atype'], dim=0) + element_data = torch.cat(sys_stat["atype"], dim=0) collect_values = torch.unique(element_data.flatten(), sorted=True) for elem in collect_values.tolist(): frames_with_elem = torch.any(element_data == elem, dim=1) @@ -145,8 +151,8 @@ def finalize_stats(sys_stat): collect_ele[elem] += len(row_indices) finalize_stats(sys_stat) lst.append(sys_stat) - - #get frame index + + # get frame index if datasets[0].mixed_type and enable_element_completion: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): @@ -156,35 +162,37 @@ def finalize_stats(sys_stat): if elem not in global_element_counts: global_element_counts[elem] = {"count": 0, "indices": []} if count > min_frames_per_element_forstat: - global_element_counts[elem]["count"] += min_frames_per_element_forstat + global_element_counts[elem]["count"] += ( + min_frames_per_element_forstat + ) indices = indices[:min_frames_per_element_forstat] - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: - if global_element_counts[elem]["count"] >= min_frames_per_element_forstat: + if ( + global_element_counts[elem]["count"] + >= min_frames_per_element_forstat + ): pass else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) # Complement if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] if indices_count < min_frames_per_element_forstat: log.warning( - f'The number of frames in your datasets with element {elem} is {indices_count}, ' - f'which is less than the required {min_frames_per_element_forstat}' + f"The number of frames in your datasets with element {elem} is {indices_count}, " + f"which is less than the required {min_frames_per_element_forstat}" ) collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements @@ -194,14 +202,15 @@ def finalize_stats(sys_stat): collect_miss_element.add(ele) missing_elements.add(ele) for miss in missing_elements: - sys_indices = global_element_counts[miss].get('indices', []) + sys_indices = global_element_counts[miss].get("indices", []) if miss in collect_miss_element: newele_counter = collect_ele.get(miss, 0) else: newele_counter = 0 - process_with_new_frame(sys_indices,newele_counter) + process_with_new_frame(sys_indices, newele_counter) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], @@ -523,6 +532,7 @@ def compute_output_stats( std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} return bias_atom_e, std_atom_e + def compute_output_stats_global( sampled: list[dict], ntypes: int,