Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 5, 2025
1 parent 379d4ad commit 0dabf77
Showing 1 changed file with 45 additions and 35 deletions.
80 changes: 45 additions & 35 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import numpy as np
import torch
from collections import (
defaultdict,
)
Expand All @@ -11,13 +9,9 @@
Union,
)

from deepmd.dpmodel.output_def import (
FittingOutputDef,
)


import numpy as np
import torch

from deepmd.pt.utils import (
AtomExcludeMask,
)
Expand All @@ -40,7 +34,14 @@

log = logging.getLogger(__name__)

def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True):

def make_stat_input(
datasets,
dataloaders,
nbatches,
min_frames_per_element_forstat,
enable_element_completion=True,
):
"""Pack data for statistics.
Element checking is only enabled with mixed_type.
Expand All @@ -63,11 +64,13 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors
if datasets[0].mixed_type:
if enable_element_completion:
log.info(
f'Element check enabled. '
f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.'
f"Element check enabled. "
f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
)
else:
log.info("Element completion is disabled. Skipping missing element handling.")
log.info(
"Element completion is disabled. Skipping missing element handling."
)

def process_batches(dataloader, sys_stat):
"""Process batches from a dataloader to collect statistics."""
Expand All @@ -93,8 +96,8 @@ def process_batches(dataloader, sys_stat):

def process_with_new_frame(sys_indices, newele_counter):
for sys_info in sys_indices:
sys_index = sys_info['sys_index']
frames = sys_info['frames']
sys_index = sys_info["sys_index"]
frames = sys_info["frames"]
sys = datasets[sys_index]
for frame in frames:
newele_counter += 1
Expand Down Expand Up @@ -126,7 +129,10 @@ def finalize_stats(sys_stat):
for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)):
elif sys_stat[key] is None or (
isinstance(sys_stat[key], list)
and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
):
sys_stat[key] = None
elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
Expand All @@ -137,16 +143,16 @@ def finalize_stats(sys_stat):
with torch.device("cpu"):
process_batches(dataloader, sys_stat)
if datasets[0].mixed_type and enable_element_completion:
element_data = torch.cat(sys_stat['atype'], dim=0)
element_data = torch.cat(sys_stat["atype"], dim=0)
collect_values = torch.unique(element_data.flatten(), sorted=True)
for elem in collect_values.tolist():
frames_with_elem = torch.any(element_data == elem, dim=1)
row_indices = torch.where(frames_with_elem)[0]
collect_ele[elem] += len(row_indices)
finalize_stats(sys_stat)
lst.append(sys_stat)
#get frame index

# get frame index
if datasets[0].mixed_type and enable_element_completion:
element_counts = dataset.get_frame_index()
for elem, data in element_counts.items():
Expand All @@ -156,35 +162,37 @@ def finalize_stats(sys_stat):
if elem not in global_element_counts:
global_element_counts[elem] = {"count": 0, "indices": []}
if count > min_frames_per_element_forstat:
global_element_counts[elem]["count"] += min_frames_per_element_forstat
global_element_counts[elem]["count"] += (
min_frames_per_element_forstat
)
indices = indices[:min_frames_per_element_forstat]
global_element_counts[elem]["indices"].append({
"sys_index": sys_index,
"frames": indices
})
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
else:
global_element_counts[elem]["count"] += count
global_element_counts[elem]["indices"].append({
"sys_index": sys_index,
"frames": indices
})
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
else:
if global_element_counts[elem]["count"] >= min_frames_per_element_forstat:
if (
global_element_counts[elem]["count"]
>= min_frames_per_element_forstat
):
pass
else:
global_element_counts[elem]["count"] += count
global_element_counts[elem]["indices"].append({
"sys_index": sys_index,
"frames": indices
})
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
# Complement
if datasets[0].mixed_type and enable_element_completion:
for elem, data in global_element_counts.items():
indices_count = data["count"]
if indices_count < min_frames_per_element_forstat:
log.warning(
f'The number of frames in your datasets with element {elem} is {indices_count}, '
f'which is less than the required {min_frames_per_element_forstat}'
f"The number of frames in your datasets with element {elem} is {indices_count}, "
f"which is less than the required {min_frames_per_element_forstat}"
)
collect_elements = collect_ele.keys()
missing_elements = total_element_types - collect_elements
Expand All @@ -194,14 +202,15 @@ def finalize_stats(sys_stat):
collect_miss_element.add(ele)
missing_elements.add(ele)
for miss in missing_elements:
sys_indices = global_element_counts[miss].get('indices', [])
sys_indices = global_element_counts[miss].get("indices", [])
if miss in collect_miss_element:
newele_counter = collect_ele.get(miss, 0)
else:
newele_counter = 0
process_with_new_frame(sys_indices,newele_counter)
process_with_new_frame(sys_indices, newele_counter)
return lst


def _restore_from_file(
stat_file_path: DPPath,
keys: list[str] = ["energy"],
Expand Down Expand Up @@ -523,6 +532,7 @@ def compute_output_stats(
std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()}
return bias_atom_e, std_atom_e


def compute_output_stats_global(
sampled: list[dict],
ntypes: int,
Expand Down

0 comments on commit 0dabf77

Please sign in to comment.