Skip to content

Commit

Permalink
Add assert to ensure that the new frame contains the required elements
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Jan 6, 2025
1 parent 27999af commit 817d2ec
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import numpy as np
import torch
from collections import (
defaultdict,
)
Expand All @@ -10,6 +8,8 @@
Optional,
Union,
)
import numpy as np
import torch
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)
Expand All @@ -35,7 +35,6 @@

log = logging.getLogger(__name__)


def make_stat_input(
datasets,
dataloaders,
Expand Down Expand Up @@ -95,7 +94,7 @@ def process_batches(dataloader, sys_stat):
else:
pass

def process_with_new_frame(sys_indices, newele_counter):
def process_with_new_frame(sys_indices, newele_counter, miss):
for sys_info in sys_indices:
sys_index = sys_info["sys_index"]
frames = sys_info["frames"]
Expand All @@ -104,6 +103,7 @@ def process_with_new_frame(sys_indices, newele_counter):
newele_counter += 1
if newele_counter <= min_frames_per_element_forstat:
frame_data = sys.__getitem__(frame)
assert miss in frame_data['atype'], f"Missing element '{miss}' not found in frame data."
sys_stat_new = {}
for dd in frame_data:
if dd == "type":
Expand Down Expand Up @@ -208,10 +208,9 @@ def finalize_stats(sys_stat):
newele_counter = collect_ele.get(miss, 0)
else:
newele_counter = 0
process_with_new_frame(sys_indices, newele_counter)
process_with_new_frame(sys_indices, newele_counter, miss)
return lst


def _restore_from_file(
stat_file_path: DPPath,
keys: list[str] = ["energy"],
Expand Down

0 comments on commit 817d2ec

Please sign in to comment.